Single place for TP layers + Dropout Layer Norm + FastLinear (#329)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2023-05-15 17:30:47 +02:00 committed by GitHub
parent 66b277321d
commit f58f0a0364
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 303 additions and 698 deletions

View File

@ -28,17 +28,16 @@ from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
HAS_BITS_AND_BYTES = False
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
)
class LlamaRMSNorm(nn.Module):
@ -91,216 +90,6 @@ class LlamaRMSNorm(nn.Module):
return normed_hidden_states, res
class FastLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:
return self.bnb_linear(input)
else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
out_features = out_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
class TensorParallelRowLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super(TensorParallelRowLinear, self).forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
# Additional entry that will map to zero
# Used for masking
self.null_idx = block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
def add_null_idx(self):
"""Additional 0 entry used for masking"""
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
# translate for [0, self.max_id - self.min_id[
input = torch.where(
(self.min_id > input) | (input >= self.max_id),
self.null_idx,
input - self.min_id,
)
out = super().forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
return out
class PositionRotaryEmbedding(RotaryEmbedding):
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
"""
Return cos and sin for the asked position ids
"""
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,

View File

@ -30,265 +30,17 @@ from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
HAS_BITS_AND_BYTES = False
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
return super(FastLayerNorm, self).forward(hidden_states), residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return normed_hidden_states, residual
class FastLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:
return self.bnb_linear(input)
else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
out_features = out_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
class TensorParallelRowLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super(TensorParallelRowLinear, self).forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
# Additional entry that will map to zero
# Used for masking
self.null_idx = block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
def add_null_idx(self):
"""Additional 0 entry used for masking"""
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
# translate for [0, self.max_id - self.min_id[
input = torch.where(
(self.min_id > input) | (input >= self.max_id),
self.null_idx,
input - self.min_id,
)
out = super().forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
return out
class PositionRotaryEmbedding(RotaryEmbedding):
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
"""
Return cos and sin for the asked position ids
"""
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
FastLayerNorm,
PositionRotaryEmbedding,
)
class FlashNeoxAttention(torch.nn.Module):

View File

@ -9,224 +9,13 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
HAS_BITS_AND_BYTES = False
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
return super(FastLayerNorm, self).forward(hidden_states), residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return normed_hidden_states, residual
class FastLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:
return self.bnb_linear(input)
else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
out_features = out_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
class TensorParallelRowLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super(TensorParallelRowLinear, self).forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
reduce=True,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.reduce = reduce
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
# Additional entry that will map to zero
# Used for masking
self.null_idx = block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
def add_null_idx(self):
"""Additional 0 entry used for masking"""
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
# translate for [0, self.max_id - self.min_id[
input = torch.where(
(self.min_id > input) | (input >= self.max_id),
self.null_idx,
input - self.min_id,
)
out = super().forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
FastLayerNorm,
)
class FlashMQAttention(torch.nn.Module):

View File

@ -10,23 +10,26 @@ from transformers import (
AutoModelForSeq2SeqLM,
AutoConfig,
)
from transformers.models.t5.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)
from text_generation_server.utils.layers import (
FastLinear,
)
from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
except ImportError as e:
HAS_BITS_AND_BYTES = False

View File

@ -0,0 +1,272 @@
import torch
from torch import nn
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
HAS_BITS_AND_BYTES = False
class FastLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:
return self.bnb_linear(input)
else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
out_features = out_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
class TensorParallelRowLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super(TensorParallelRowLinear, self).forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
# Additional entry that will map to zero
# Used for masking
self.null_idx = block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
)
def add_null_idx(self):
"""Additional 0 entry used for masking"""
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
# translate for [0, self.max_id - self.min_id[
input = torch.where(
(self.min_id > input) | (input >= self.max_id),
self.null_idx,
input - self.min_id,
)
out = super().forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
return out
try:
import dropout_layer_norm
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
return super(FastLayerNorm, self).forward(hidden_states), residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return normed_hidden_states, residual
except ImportError:
pass
try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
class PositionRotaryEmbedding(RotaryEmbedding):
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
def get_cos_sin(
self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype
):
"""
Return cos and sin for the asked position ids
"""
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
except ImportError:
pass