feat(server): Clarify CausalLMBatch concatenate method

This commit is contained in:
OlivierDehaene 2022-11-09 18:24:07 +01:00
parent fa43fb71be
commit dccd5c2b1a
2 changed files with 86 additions and 84 deletions

52
Cargo.lock generated
View File

@ -213,9 +213,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.0.73"
version = "1.0.76"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
checksum = "76a284da2e6fe2092f2353e51713435363112dfd60030e22add80be333fb928f"
[[package]]
name = "cfg-if"
@ -240,9 +240,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.0.18"
version = "4.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b"
checksum = "91b9970d7505127a162fdaa9b96428d28a479ba78c9ec7550a63a5d9863db682"
dependencies = [
"atty",
"bitflags",
@ -255,9 +255,9 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.0.18"
version = "4.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3"
checksum = "0177313f9f02afc995627906bbd8967e2be069f5261954222dac78290c2b9014"
dependencies = [
"heck 0.4.0",
"proc-macro-error",
@ -790,9 +790,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
[[package]]
name = "hyper"
version = "0.14.20"
version = "0.14.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac"
checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c"
dependencies = [
"bytes",
"futures-channel",
@ -898,9 +898,9 @@ dependencies = [
[[package]]
name = "ipnet"
version = "2.5.0"
version = "2.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b"
checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745"
[[package]]
name = "itertools"
@ -1053,9 +1053,9 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "native-tls"
version = "0.2.10"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd7e2f3618557f980e0b17e8856252eee3c97fa12c54dff0ca290fb6266ca4a9"
checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"
dependencies = [
"lazy_static",
"libc",
@ -1103,9 +1103,9 @@ dependencies = [
[[package]]
name = "num_cpus"
version = "1.13.1"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
checksum = "f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5"
dependencies = [
"hermit-abi",
"libc",
@ -1125,9 +1125,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "once_cell"
version = "1.15.0"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
[[package]]
name = "onig"
@ -1293,9 +1293,9 @@ checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "ppv-lite86"
version = "0.2.16"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro-error"
@ -1479,9 +1479,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.6.0"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b"
checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a"
dependencies = [
"aho-corasick",
"memchr",
@ -1490,9 +1490,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.6.27"
version = "0.6.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244"
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
[[package]]
name = "remove_dir_all"
@ -1802,7 +1802,7 @@ dependencies = [
name = "text-generation-launcher"
version = "0.1.0"
dependencies = [
"clap 4.0.18",
"clap 4.0.22",
"ctrlc",
"subprocess",
"tracing",
@ -1814,7 +1814,7 @@ name = "text-generation-router"
version = "0.1.0"
dependencies = [
"axum",
"clap 4.0.18",
"clap 4.0.22",
"futures",
"parking_lot",
"serde",
@ -1893,9 +1893,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokenizers"
version = "0.13.1"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170"
checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69"
dependencies = [
"aho-corasick",
"cached-path",

View File

@ -148,71 +148,73 @@ class CausalLMBatch:
] = batch.attention_mask[:, -batch.max_sequence_length :]
for j, past in enumerate(batch.past_key_values):
past_keys, past_values = past
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
head_dim, padded_sequence_length = past[0].shape[-2:]
num_heads = (
past[0]
.view(batch.size, -1, head_dim, padded_sequence_length)
.shape[1]
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
_, num_heads, head_dim, padded_sequence_length = past_keys.shape
padded_past_keys_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
)
# head_dim is last for BLOOM
if past_values.shape[-1] == head_dim:
past_values_head_dim_last = True
padded_past_values_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
)
elif past_values.shape[-2] == head_dim:
past_values_head_dim_last = False
padded_past_values_shape = padded_past_keys_shape
else:
raise ValueError(
f"past_values shape {past_values.shape} is not valid"
)
# This will run only once per layer
if j == len(past_key_values):
past_key_values.append([])
padded_past_keys = torch.zeros(
padded_past_keys_shape,
dtype=past_keys.dtype,
device=past_keys.device,
)
padded_past_values = torch.zeros(
padded_past_values_shape,
dtype=past_values.dtype,
device=past_values.device,
)
past_key_values.append((padded_past_keys, padded_past_values))
# Decoder past
for k, t in enumerate(past):
# Needed because BLOOM past shapes are not the same for keys and values
# Keys: [batch_size * num_heads, head_dim, seq_length]
# Values: [batch_size * num_heads, seq_length, head_dim]
head_dim_last = False
if t.shape[-2] == head_dim:
t = t.view(
batch.size, num_heads, head_dim, padded_sequence_length
)
padded_t_shape = (
total_batch_size,
num_heads,
head_dim,
max_sequence_length - 1,
)
elif t.shape[-1] == head_dim:
head_dim_last = True
t = t.view(
batch.size, num_heads, padded_sequence_length, head_dim
)
padded_t_shape = (
total_batch_size,
num_heads,
max_sequence_length - 1,
head_dim,
)
else:
raise ValueError(f"shape {t.shape} is not valid")
# We slice the past keys and values to remove the padding from previous batches
past_key_values[j][0][
start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
# Initialize tensors
# This will run only once per layer and per past tensor
if k == len(past_key_values[j]):
past_key_values[j].append(
torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
)
# We slice the past keys and values to remove the padding from previous batches
if not head_dim_last:
past_key_values[j][k][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = t[:, :, :, -(batch.max_sequence_length - 1) :]
else:
past_key_values[j][k][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
:,
] = t[:, :, -(batch.max_sequence_length - 1) :, :]
if past_values_head_dim_last:
past_key_values[j][1][
start_index:end_index,
:,
-(batch.max_sequence_length - 1) :,
:,
] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
else:
past_key_values[j][1][
start_index:end_index,
:,
:,
-(batch.max_sequence_length - 1) :,
] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
start_index += batch.size