From dccd5c2b1acdee242d847e1fbeea2edabfdde15f Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 9 Nov 2022 18:24:07 +0100 Subject: [PATCH] feat(server): Clarify CausalLMBatch concatenate method --- Cargo.lock | 52 ++++----- server/text_generation/models/causal_lm.py | 118 +++++++++++---------- 2 files changed, 86 insertions(+), 84 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8f09c3a5..0fd5c4bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,9 +213,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.73" +version = "1.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +checksum = "76a284da2e6fe2092f2353e51713435363112dfd60030e22add80be333fb928f" [[package]] name = "cfg-if" @@ -240,9 +240,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.0.18" +version = "4.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b" +checksum = "91b9970d7505127a162fdaa9b96428d28a479ba78c9ec7550a63a5d9863db682" dependencies = [ "atty", "bitflags", @@ -255,9 +255,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.0.18" +version = "4.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3" +checksum = "0177313f9f02afc995627906bbd8967e2be069f5261954222dac78290c2b9014" dependencies = [ "heck 0.4.0", "proc-macro-error", @@ -790,9 +790,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.20" +version = "0.14.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac" +checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c" dependencies = [ "bytes", "futures-channel", @@ -898,9 +898,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.5.0" +version = "2.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" +checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745" [[package]] name = "itertools" @@ -1053,9 +1053,9 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" [[package]] name = "native-tls" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd7e2f3618557f980e0b17e8856252eee3c97fa12c54dff0ca290fb6266ca4a9" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" dependencies = [ "lazy_static", "libc", @@ -1103,9 +1103,9 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" +checksum = "f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5" dependencies = [ "hermit-abi", "libc", @@ -1125,9 +1125,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "once_cell" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" +checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" [[package]] name = "onig" @@ -1293,9 +1293,9 @@ checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" [[package]] name = "ppv-lite86" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro-error" @@ -1479,9 +1479,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b" +checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a" dependencies = [ "aho-corasick", "memchr", @@ -1490,9 +1490,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.27" +version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" [[package]] name = "remove_dir_all" @@ -1802,7 +1802,7 @@ dependencies = [ name = "text-generation-launcher" version = "0.1.0" dependencies = [ - "clap 4.0.18", + "clap 4.0.22", "ctrlc", "subprocess", "tracing", @@ -1814,7 +1814,7 @@ name = "text-generation-router" version = "0.1.0" dependencies = [ "axum", - "clap 4.0.18", + "clap 4.0.22", "futures", "parking_lot", "serde", @@ -1893,9 +1893,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokenizers" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170" +checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69" dependencies = [ "aho-corasick", "cached-path", diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index 2ba36b1a..c1057635 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -148,71 +148,73 @@ class CausalLMBatch: ] = batch.attention_mask[:, -batch.max_sequence_length :] for j, past in enumerate(batch.past_key_values): + past_keys, past_values = past + # Shenanigans to get dimensions because BLOOM outputs a past with a different shape - # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...] - head_dim, padded_sequence_length = past[0].shape[-2:] - num_heads = ( - past[0] - .view(batch.size, -1, head_dim, padded_sequence_length) - .shape[1] + # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length] + # BLOOM Values: [batch_size * num_heads, seq_length, head_dim] + past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:]) + past_values = past_values.view(batch.size, -1, *past_values.shape[-2:]) + + _, num_heads, head_dim, padded_sequence_length = past_keys.shape + + padded_past_keys_shape = ( + total_batch_size, + num_heads, + head_dim, + max_sequence_length - 1, ) + # head_dim is last for BLOOM + if past_values.shape[-1] == head_dim: + past_values_head_dim_last = True + padded_past_values_shape = ( + total_batch_size, + num_heads, + max_sequence_length - 1, + head_dim, + ) + elif past_values.shape[-2] == head_dim: + past_values_head_dim_last = False + padded_past_values_shape = padded_past_keys_shape + else: + raise ValueError( + f"past_values shape {past_values.shape} is not valid" + ) + # This will run only once per layer if j == len(past_key_values): - past_key_values.append([]) + padded_past_keys = torch.zeros( + padded_past_keys_shape, + dtype=past_keys.dtype, + device=past_keys.device, + ) + padded_past_values = torch.zeros( + padded_past_values_shape, + dtype=past_values.dtype, + device=past_values.device, + ) + past_key_values.append((padded_past_keys, padded_past_values)) - # Decoder past - for k, t in enumerate(past): - # Needed because BLOOM past shapes are not the same for keys and values - # Keys: [batch_size * num_heads, head_dim, seq_length] - # Values: [batch_size * num_heads, seq_length, head_dim] - head_dim_last = False - if t.shape[-2] == head_dim: - t = t.view( - batch.size, num_heads, head_dim, padded_sequence_length - ) - padded_t_shape = ( - total_batch_size, - num_heads, - head_dim, - max_sequence_length - 1, - ) - elif t.shape[-1] == head_dim: - head_dim_last = True - t = t.view( - batch.size, num_heads, padded_sequence_length, head_dim - ) - padded_t_shape = ( - total_batch_size, - num_heads, - max_sequence_length - 1, - head_dim, - ) - else: - raise ValueError(f"shape {t.shape} is not valid") + # We slice the past keys and values to remove the padding from previous batches + past_key_values[j][0][ + start_index:end_index, :, :, -(batch.max_sequence_length - 1) : + ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :] - # Initialize tensors - # This will run only once per layer and per past tensor - if k == len(past_key_values[j]): - past_key_values[j].append( - torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device) - ) - - # We slice the past keys and values to remove the padding from previous batches - if not head_dim_last: - past_key_values[j][k][ - start_index:end_index, - :, - :, - -(batch.max_sequence_length - 1) :, - ] = t[:, :, :, -(batch.max_sequence_length - 1) :] - else: - past_key_values[j][k][ - start_index:end_index, - :, - -(batch.max_sequence_length - 1) :, - :, - ] = t[:, :, -(batch.max_sequence_length - 1) :, :] + if past_values_head_dim_last: + past_key_values[j][1][ + start_index:end_index, + :, + -(batch.max_sequence_length - 1) :, + :, + ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :] + else: + past_key_values[j][1][ + start_index:end_index, + :, + :, + -(batch.max_sequence_length - 1) :, + ] = past_values[:, :, :, -(batch.max_sequence_length - 1) :] start_index += batch.size