2022-10-28 11:24:00 -06:00
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
|
2023-06-08 06:51:52 -06:00
|
|
|
from typing import Any, Optional, Type
|
2022-10-28 11:24:00 -06:00
|
|
|
|
2023-01-20 04:24:39 -07:00
|
|
|
from transformers import (
|
|
|
|
AutoTokenizer,
|
|
|
|
AutoConfig,
|
|
|
|
PreTrainedTokenizerBase,
|
|
|
|
)
|
2022-10-28 11:24:00 -06:00
|
|
|
|
2023-06-08 06:51:52 -06:00
|
|
|
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
|
|
|
BloomForCausalLM,
|
|
|
|
)
|
2023-03-07 10:52:22 -07:00
|
|
|
from text_generation_server.models import CausalLM
|
|
|
|
from text_generation_server.models.causal_lm import CausalLMBatch
|
|
|
|
from text_generation_server.pb import generate_pb2
|
|
|
|
from text_generation_server.utils import (
|
2022-10-28 11:24:00 -06:00
|
|
|
initialize_torch_distributed,
|
|
|
|
weight_files,
|
2023-06-08 06:51:52 -06:00
|
|
|
Weights,
|
2022-10-28 11:24:00 -06:00
|
|
|
)
|
|
|
|
|
|
|
|
|
2022-12-08 10:49:33 -07:00
|
|
|
class BloomCausalLMBatch(CausalLMBatch):
    """``CausalLMBatch`` specialization used by the BLOOM model.

    Construction is fully delegated to ``CausalLMBatch.from_pb``; the only
    difference is that the resulting batch has ``keys_head_dim_last`` set to
    ``False``.
    NOTE(review): presumably this matches the layout of BLOOM's cached key
    tensors (head dim not last) — confirm against ``CausalLMBatch``.
    """

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Create a batch from a protobuf ``Batch`` request.

        Args:
            pb: Incoming protobuf batch.
            tokenizer: Tokenizer forwarded to ``CausalLMBatch.from_pb``.
            dtype: Tensor dtype forwarded to ``CausalLMBatch.from_pb``.
            device: Target device forwarded to ``CausalLMBatch.from_pb``.

        Returns:
            The batch built by the parent class, with ``keys_head_dim_last``
            cleared.
        """
        built = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
        # BLOOM-specific flag; everything else comes from the parent class.
        built.keys_head_dim_last = False
        return built
|
|
|
|
|
|
|
|
|
2023-06-08 06:51:52 -06:00
|
|
|
class BLOOMSharded(CausalLM):
    """BLOOM causal LM whose weights are loaded across a torch.distributed group.

    The model is the custom ``BloomForCausalLM`` built from ``Weights`` read
    out of the checkpoint's ``.safetensors`` files. Batches are represented
    by ``BloomCausalLMBatch``.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load the tokenizer, config and weights for ``model_id``.

        Args:
            model_id: Checkpoint identifier passed to ``from_pretrained`` and
                ``weight_files``.
            revision: Checkpoint revision, forwarded to every loader call.
            quantize: Quantization method stored as ``config.quantize``. For
                ``"gptq"`` and ``"marlin"`` the GPTQ parameters are also
                loaded into ``Weights``.
            speculator: Stored as ``config.speculator``.
            dtype: Weight dtype. Defaults to ``torch.float16`` when CUDA is
                available and ``torch.float32`` otherwise.
            trust_remote_code: Forwarded to the tokenizer and config loaders.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            # One CUDA device per rank.
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            slow_but_exact=False,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        # NOTE(review): hard-coded pad token id; presumably BLOOM's <pad>
        # token — confirm against the tokenizer vocabulary.
        config.pad_token_id = 3
        # config.quantize is read below to decide whether GPTQ parameters are
        # needed; presumably BloomForCausalLM also reads both fields — confirm.
        config.quantize = quantize
        config.speculator = speculator

        # Synchronize all ranks before reading the weight files.
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            prefix="transformer",
        )
        if config.quantize in ("gptq", "marlin"):
            weights._set_gptq_params(model_id, revision)

        model = BloomForCausalLM(config, weights)

        # Wait until every rank has built its model before finishing init.
        torch.distributed.barrier(group=self.process_group)
        # super(CausalLM, self) intentionally skips CausalLM.__init__ and
        # calls the next initializer in the MRO with the model built above.
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model: ``BloomCausalLMBatch``."""
        return BloomCausalLMBatch

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values: Optional[Any] = None,
    ):
        """Run the model once with the key/value cache enabled.

        Args:
            input_ids: Token ids forwarded to the model.
            attention_mask: Attention mask forwarded to the model.
            position_ids: Position ids forwarded to the model.
            past_key_values: Cache from the previous step, or ``None``.

        Returns:
            A ``(logits, speculative_logits, past_key_values)`` tuple, where
            ``past_key_values`` is the updated cache returned by the model.
        """
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        logits = outputs.logits
        return logits, speculative_logits, outputs.past_key_values
|