feat(server): Clarify CausalLMBatch concatenate method
This commit is contained in:
parent
fa43fb71be
commit
dccd5c2b1a
|
@ -213,9 +213,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.73"
|
||||
version = "1.0.76"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
|
||||
checksum = "76a284da2e6fe2092f2353e51713435363112dfd60030e22add80be333fb928f"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
|
@ -240,9 +240,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.0.18"
|
||||
version = "4.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b"
|
||||
checksum = "91b9970d7505127a162fdaa9b96428d28a479ba78c9ec7550a63a5d9863db682"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"bitflags",
|
||||
|
@ -255,9 +255,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.0.18"
|
||||
version = "4.0.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3"
|
||||
checksum = "0177313f9f02afc995627906bbd8967e2be069f5261954222dac78290c2b9014"
|
||||
dependencies = [
|
||||
"heck 0.4.0",
|
||||
"proc-macro-error",
|
||||
|
@ -790,9 +790,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
|
|||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "0.14.20"
|
||||
version = "0.14.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac"
|
||||
checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
|
@ -898,9 +898,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.5.0"
|
||||
version = "2.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b"
|
||||
checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
|
@ -1053,9 +1053,9 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
|
|||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.10"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd7e2f3618557f980e0b17e8856252eee3c97fa12c54dff0ca290fb6266ca4a9"
|
||||
checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"libc",
|
||||
|
@ -1103,9 +1103,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.13.1"
|
||||
version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
|
||||
checksum = "f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
|
@ -1125,9 +1125,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
|
|||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.15.0"
|
||||
version = "1.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
|
||||
checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
|
||||
|
||||
[[package]]
|
||||
name = "onig"
|
||||
|
@ -1293,9 +1293,9 @@ checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
|
|||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.16"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
|
@ -1479,9 +1479,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.6.0"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b"
|
||||
checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
|
@ -1490,9 +1490,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.27"
|
||||
version = "0.6.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244"
|
||||
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
|
||||
|
||||
[[package]]
|
||||
name = "remove_dir_all"
|
||||
|
@ -1802,7 +1802,7 @@ dependencies = [
|
|||
name = "text-generation-launcher"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap 4.0.18",
|
||||
"clap 4.0.22",
|
||||
"ctrlc",
|
||||
"subprocess",
|
||||
"tracing",
|
||||
|
@ -1814,7 +1814,7 @@ name = "text-generation-router"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"clap 4.0.18",
|
||||
"clap 4.0.22",
|
||||
"futures",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
|
@ -1893,9 +1893,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
|
|||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.13.1"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170"
|
||||
checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"cached-path",
|
||||
|
|
|
@ -148,71 +148,73 @@ class CausalLMBatch:
|
|||
] = batch.attention_mask[:, -batch.max_sequence_length :]
|
||||
|
||||
for j, past in enumerate(batch.past_key_values):
|
||||
past_keys, past_values = past
|
||||
|
||||
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
|
||||
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
|
||||
head_dim, padded_sequence_length = past[0].shape[-2:]
|
||||
num_heads = (
|
||||
past[0]
|
||||
.view(batch.size, -1, head_dim, padded_sequence_length)
|
||||
.shape[1]
|
||||
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
|
||||
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
|
||||
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
|
||||
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
|
||||
|
||||
_, num_heads, head_dim, padded_sequence_length = past_keys.shape
|
||||
|
||||
padded_past_keys_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
)
|
||||
|
||||
# head_dim is last for BLOOM
|
||||
if past_values.shape[-1] == head_dim:
|
||||
past_values_head_dim_last = True
|
||||
padded_past_values_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_sequence_length - 1,
|
||||
head_dim,
|
||||
)
|
||||
elif past_values.shape[-2] == head_dim:
|
||||
past_values_head_dim_last = False
|
||||
padded_past_values_shape = padded_past_keys_shape
|
||||
else:
|
||||
raise ValueError(
|
||||
f"past_values shape {past_values.shape} is not valid"
|
||||
)
|
||||
|
||||
# This will run only once per layer
|
||||
if j == len(past_key_values):
|
||||
past_key_values.append([])
|
||||
padded_past_keys = torch.zeros(
|
||||
padded_past_keys_shape,
|
||||
dtype=past_keys.dtype,
|
||||
device=past_keys.device,
|
||||
)
|
||||
padded_past_values = torch.zeros(
|
||||
padded_past_values_shape,
|
||||
dtype=past_values.dtype,
|
||||
device=past_values.device,
|
||||
)
|
||||
past_key_values.append((padded_past_keys, padded_past_values))
|
||||
|
||||
# Decoder past
|
||||
for k, t in enumerate(past):
|
||||
# Needed because BLOOM past shapes are not the same for keys and values
|
||||
# Keys: [batch_size * num_heads, head_dim, seq_length]
|
||||
# Values: [batch_size * num_heads, seq_length, head_dim]
|
||||
head_dim_last = False
|
||||
if t.shape[-2] == head_dim:
|
||||
t = t.view(
|
||||
batch.size, num_heads, head_dim, padded_sequence_length
|
||||
)
|
||||
padded_t_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
head_dim,
|
||||
max_sequence_length - 1,
|
||||
)
|
||||
elif t.shape[-1] == head_dim:
|
||||
head_dim_last = True
|
||||
t = t.view(
|
||||
batch.size, num_heads, padded_sequence_length, head_dim
|
||||
)
|
||||
padded_t_shape = (
|
||||
total_batch_size,
|
||||
num_heads,
|
||||
max_sequence_length - 1,
|
||||
head_dim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"shape {t.shape} is not valid")
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
past_key_values[j][0][
|
||||
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
|
||||
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
|
||||
# Initialize tensors
|
||||
# This will run only once per layer and per past tensor
|
||||
if k == len(past_key_values[j]):
|
||||
past_key_values[j].append(
|
||||
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
|
||||
)
|
||||
|
||||
# We slice the past keys and values to remove the padding from previous batches
|
||||
if not head_dim_last:
|
||||
past_key_values[j][k][
|
||||
start_index:end_index,
|
||||
:,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
] = t[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
else:
|
||||
past_key_values[j][k][
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
:,
|
||||
] = t[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
if past_values_head_dim_last:
|
||||
past_key_values[j][1][
|
||||
start_index:end_index,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
:,
|
||||
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
|
||||
else:
|
||||
past_key_values[j][1][
|
||||
start_index:end_index,
|
||||
:,
|
||||
:,
|
||||
-(batch.max_sequence_length - 1) :,
|
||||
] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
|
||||
|
||||
start_index += batch.size
|
||||
|
||||
|
|
Loading…
Reference in New Issue