Removing IPEX_AVAIL. (#2115)
* Removing IPEX_AVAIL. Chose to unify CPU and XPU under `ipex`: most of the code is identical, with only a few divergent spots, mainly the kv-cache layout and the flash_xxx.py files. Since those files should be factored away soon, we should not need a separate flag for them (see the sketch below).
* Forgot a few places.
* Unrelated change.
* Fixing HF_TOKEN.
* HF_TOKEN
parent 3f3b7ffd67
commit 9e2fdf57c0
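The core of the change, as reflected in the import_utils hunk at the end of this diff: availability of intel_extension_for_pytorch now selects a single SYSTEM == "ipex" backend that covers both Intel CPU and XPU, and every former IPEX_AVAIL check becomes a SYSTEM comparison. A minimal sketch of that detection, with the availability probe written as an illustrative stand-in for the repo's is_ipex_available helper:

import importlib.util

import torch


def is_ipex_available() -> bool:
    # Stand-in probe: true when intel_extension_for_pytorch can be imported,
    # which covers both the CPU-only and the XPU builds of IPEX.
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None


if torch.version.hip is not None:
    SYSTEM = "rocm"
elif torch.version.cuda is not None and torch.cuda.is_available():
    SYSTEM = "cuda"
elif is_ipex_available():
    # One backend name for CPU and XPU; the concrete device is chosen later
    # via hasattr(torch, "xpu") and torch.xpu.is_available().
    SYSTEM = "ipex"
else:
    SYSTEM = "cpu"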
@@ -178,6 +178,6 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
           export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
-          export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           echo $DOCKER_IMAGE
           pytest -s -vv integration-tests
@@ -22,5 +22,5 @@ jobs:
       - name: Run tests
         run: |
           pip install pytest pytest-asyncio
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           make python-client-tests
@@ -37,5 +37,5 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ inputs.docker_image }}
           export DOCKER_DEVICES=${{ inputs.docker_devices }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv integration-tests
@@ -28,7 +28,7 @@ jobs:
 
       - name: Start starcoder
        run: |
-          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
+          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
           sleep 10
           wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
 
@@ -72,7 +72,7 @@ jobs:
       - name: Run server tests
         run: |
           pip install pytest
-          export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv server/tests
       - name: Pre-commit checks
         run: |
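The workflow hunks above only rename the secret and environment variable to HF_TOKEN, which is the variable huggingface_hub reads (HUGGING_FACE_HUB_TOKEN being the older name). A quick sanity-check sketch of how the renamed token would be consumed, assuming huggingface_hub is installed:

import os

from huggingface_hub import HfApi

token = os.environ.get("HF_TOKEN")
api = HfApi(token=token)
# Prints the account the token resolves to, or a note when no token is set.
print(api.whoami()["name"] if token else "HF_TOKEN not set")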
@@ -455,6 +455,6 @@ class DeployedModel(BaseModel):
     # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
     # with model_ prefixes, since this disables guardrails for colliding fields:
     # https://github.com/pydantic/pydantic/issues/9177
     model_config = ConfigDict(protected_namespaces=())
     model_id: str
     sha: str
@@ -1,4 +1,4 @@
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 import os
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
@@ -7,7 +7,7 @@ if SYSTEM == "cuda":
     from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
-elif IPEX_AVAIL:
-    from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+elif SYSTEM == "ipex":
+    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
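Downstream model code keeps importing the backend-neutral names; the dispatch above resolves them to the cuda, rocm, or ipex kernels at import time. A usage sketch, assuming the text-generation-inference server package and a supported backend are available:

from text_generation_server.layers.attention import (
    SUPPORTS_WINDOWING,
    attention,
    paged_attention,
    reshape_and_cache,
)

# Callers branch on capability flags such as SUPPORTS_WINDOWING,
# never on the backend name itself.
print(SUPPORTS_WINDOWING)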
@@ -3,7 +3,6 @@ from torch import nn
 from accelerate import init_empty_weights
 from text_generation_server.utils.import_utils import (
     SYSTEM,
-    IPEX_AVAIL,
 )
 
 
@@ -83,7 +82,7 @@ elif SYSTEM == "rocm":
 
             return super().forward(hidden_states), residual
 
-elif IPEX_AVAIL:
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
 
     class FastLayerNorm(nn.LayerNorm):
@@ -112,7 +111,7 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)
 
     def forward(self, hidden_states, residual=None):
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
             out = ipex.llm.functional.add_rms_norm(
                 residual,
                 hidden_states,
@@ -2,14 +2,14 @@ import os
 import torch
 from torch import nn
 
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "cuda":
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
-elif IPEX_AVAIL:
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
 
 
@@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):
 
             # Inplace operation, updating query and key.
             ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif IPEX_AVAIL:
+        elif SYSTEM == "ipex":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
@@ -3,9 +3,9 @@ from torch.nn import functional as F
 from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
-if IPEX_AVAIL:
+if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
 
 
@@ -100,7 +100,7 @@ class TensorParallelHead(SuperLayer):
             local_out = gather_input.T
 
             torch.mm(input, self.linear.weight.T, out=local_out)
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -117,7 +117,7 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -217,7 +217,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
@@ -257,7 +257,7 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
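The tensor-parallel layers above all follow the same dispatch for collectives: ipex.distributed when SYSTEM == "ipex", torch.distributed otherwise. A sketch of that pattern factored into a hypothetical helper (the repo keeps the branch inline in each layer):

import torch

from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    import intel_extension_for_pytorch as ipex


def all_reduce_inplace(out: torch.Tensor, process_group) -> torch.Tensor:
    # Hypothetical helper mirroring TensorParallelRowLinear / TensorParallelEmbedding.
    if SYSTEM == "ipex":
        ipex.distributed.all_reduce(out, group=process_group)
    else:
        torch.distributed.all_reduce(out, group=process_group)
    return out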
@@ -20,9 +20,9 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
-if not IPEX_AVAIL:
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 
 from text_generation_server.layers.attention import (
@@ -24,9 +24,9 @@ import torch.distributed
 import numpy as np
 
 from torch import nn
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
-if not IPEX_AVAIL:
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
@@ -768,12 +768,9 @@ class FlashCausalLM(Model):
         empty_cache()
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "xpu":
-            x = 1
-        else:
-            x = BLOCK_SIZE // element_size
+        x = BLOCK_SIZE // element_size
 
-        if IPEX_AVAIL and SYSTEM == "cpu":
+        if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
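With the xpu special case gone, x is always BLOCK_SIZE // element_size; only the ipex-on-CPU path keeps a distinct kv-cache allocation. A worked example of the x computation, with BLOCK_SIZE = 16 assumed here to match the paged-attention default in this file:

import torch

BLOCK_SIZE = 16  # assumed paged-attention block size
dtype = torch.float16

element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for float16
x = BLOCK_SIZE // element_size                                # 16 // 2 == 8
print(element_size, x)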
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -34,12 +34,12 @@ class FlashGPT2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")
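The same device/dtype selection is repeated below for FlashLlama, BaseFlashMistral, FlashNeoXSharded, FlashRWSharded, and FlashSantacoderSharded: under SYSTEM == "ipex" the model lands on xpu:{rank} when an XPU is visible and on cpu otherwise, with float16 as the default dtype either way. A sketch of the shared logic written as a hypothetical helper (the repo repeats it inline in each class):

import torch


def pick_device_and_dtype(rank: int, dtype, system: str):
    # Hypothetical helper mirroring the constructor logic above.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16 if dtype is None else dtype
    elif system == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
        else:
            device = torch.device("cpu")
        dtype = torch.float16 if dtype is None else dtype
    else:
        raise NotImplementedError("Flash models are only available on GPU")
    return device, dtype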
@@ -17,7 +17,7 @@ from text_generation_server.utils import (
 
 tracer = trace.get_tracer(__name__)
 
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class FlashLlama(FlashCausalLM):
@@ -34,12 +34,12 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -16,7 +16,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -38,12 +38,12 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
@@ -14,7 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -33,12 +33,12 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -34,12 +34,12 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
@@ -18,7 +18,7 @@ from text_generation_server.utils import (
     Weights,
 )
 
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -37,12 +37,12 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@@ -3,7 +3,7 @@ import torch
 
 from datetime import timedelta
 from loguru import logger
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
@@ -69,7 +69,7 @@ def initialize_torch_distributed():
 
     if not torch.distributed.is_initialized():
         # Call the init process.
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
             import intel_extension_for_pytorch as ipex
 
             ipex.distributed.init_process_group(
@@ -37,7 +37,10 @@ def get_cpu_free_memory(device, memory_fraction):
     return free_memory
 
 
-IPEX_AVAIL = is_ipex_available()
+def noop(*args, **kwargs):
+    pass
+
 
 SYSTEM = None
 if torch.version.hip is not None:
     SYSTEM = "rocm"
@@ -49,17 +52,19 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
     empty_cache = torch.cuda.empty_cache
     synchronize = torch.cuda.synchronize
     get_free_memory = get_cuda_free_memory
-elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available():
-    SYSTEM = "xpu"
-    empty_cache = torch.xpu.empty_cache
-    synchronize = torch.xpu.synchronize
-    get_free_memory = get_xpu_free_memory
+elif is_ipex_available():
+    SYSTEM = "ipex"
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        empty_cache = torch.xpu.empty_cache
+        synchronize = torch.xpu.synchronize
+        get_free_memory = get_xpu_free_memory
+    else:
+        empty_cache = noop
+        synchronize = noop
+        get_free_memory = get_cpu_free_memory
 else:
     SYSTEM = "cpu"
 
-    def noop(*args, **kwargs):
-        pass
-
     empty_cache = noop
     synchronize = noop
     get_free_memory = get_cpu_free_memory
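Because a CPU-only IPEX install now falls back to the module-level no-op helpers, call sites stay backend-agnostic. A usage sketch, assuming the text-generation-inference server package is importable:

from text_generation_server.utils.import_utils import (
    SYSTEM,
    empty_cache,
    synchronize,
)

print(SYSTEM)  # "cuda", "rocm", "ipex", or "cpu"
empty_cache()  # torch.cuda/torch.xpu empty_cache, or a no-op fallback
synchronize()  # same dispatch as above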