Merge branch 'main' into gpt_awq_4

Commit: 8c3859d153
@@ -32,10 +32,6 @@ jobs:
     permissions:
       contents: write
       packages: write
-      # This is used to complete the identity challenge
-      # with sigstore/fulcio when running outside of PRs.
-      id-token: write
-      security-events: write
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -39,6 +39,9 @@ jobs:
       matrix:
         hardware: ["cuda", "rocm", "intel-xpu", "intel-cpu"]
     uses: ./.github/workflows/build.yaml # calls the one above ^
+    permissions:
+      contents: write
+      packages: write
     with:
       hardware: ${{ matrix.hardware }}
       # https://github.com/actions/runner/issues/2206
@@ -35,7 +35,7 @@ jobs:
         with:
           # Released on: 02 May, 2024
          # https://releases.rs/docs/1.78.0/
-          toolchain: 1.79.0
+          toolchain: 1.80.0
           override: true
           components: rustfmt, clippy
      - name: Install Protoc
@@ -19,3 +19,6 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp
 data/
 load_tests/*.json
 server/fbgemmm
+
+.direnv/
+.venv/
@@ -77,3 +77,4 @@ docs/openapi.json:
   - '#/paths/~1tokenize/post'
   - '#/paths/~1v1~1chat~1completions/post'
   - '#/paths/~1v1~1completions/post'
+  - '#/paths/~1v1~1models/get'
(File diff suppressed because it is too large.)
@@ -29,6 +29,8 @@ tokenizers = { version = "0.19.1", features = ["http"] }
 hf-hub = { version = "0.3.1", features = ["tokio"] }
 metrics = { version = "0.23.0" }
 metrics-exporter-prometheus = { version = "0.15.1", features = [] }
+minijinja = { version = "2.2.0", features = ["json"] }
+minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }

 [profile.release]
 incremental = true
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -184,6 +184,12 @@ WORKDIR /usr/src
 COPY server/Makefile-selective-scan Makefile
 RUN make build-all

+# Build flashinfer
+FROM kernel-builder AS flashinfer-builder
+WORKDIR /usr/src
+COPY server/Makefile-flashinfer Makefile
+RUN make install-flashinfer
+
 # Text Generation Inference base image
 FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base

@@ -236,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
 # Copy build artifacts from mamba builder
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
+COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/

 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir
@@ -1,5 +1,5 @@
 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -1,6 +1,6 @@
 ARG PLATFORM=xpu

-FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -178,5 +178,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
@@ -189,6 +189,8 @@ overridden with the `--otlp-service-name` argument

 ![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)

+Detailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
+
 ### Local install

 You can also opt to install `text-generation-inference` locally.
@@ -153,6 +153,8 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Most request will have that
+                add_special_tokens: true,
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
                 slots: vec![],
@@ -221,6 +221,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -53,8 +53,8 @@ utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
     "opentelemetry-otlp",
 ] }
-minijinja = { version = "2.0.2" }
-minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+minijinja = { workspace = true }
+minijinja-contrib = { workspace = true }
 futures-util = "0.3.30"
 regex = "1.10.3"
 once_cell = "1.19.0"
@@ -35,27 +35,15 @@ impl BackendV3 {
         window_size: Option<u32>,
         speculate: u32,
     ) -> Self {
-        let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") {
-            matches!(prefix_caching.as_str(), "true" | "1")
-        } else {
-            false
-        };
-        let attention = if let Ok(attention) = std::env::var("ATTENTION") {
-            attention
-                .parse()
-                .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
-        } else if prefix_caching {
-            Attention::FlashInfer
-        } else {
-            Attention::Paged
-        };
-        let block_size = if attention == Attention::FlashDecoding {
-            256
-        } else if attention == Attention::FlashInfer {
-            1
-        } else {
-            16
-        };
+        let prefix_caching =
+            std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
+        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
+        let attention: String = std::env::var("ATTENTION").expect("attention env var");
+
+        let attention: Attention = attention
+            .parse()
+            .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
+        let block_size = attention.block_size();

         let queue = Queue::new(
             requires_padding,
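For context, the block sizes that the removed inline `if` hard-coded (256 for flash-decoding, 1 for flashinfer, 16 otherwise) are now expected to come from `Attention::block_size()`, and the launcher/Dockerfiles must set the `ATTENTION` and `USE_PREFIX_CACHING` env vars (the Intel image above defaults to `ATTENTION=paged`). A minimal sketch of that contract, using an illustrative stand-in enum rather than the repo's actual type; the string values other than "paged" are assumptions:

```rust
use std::str::FromStr;

// Illustrative stand-in for the `Attention` type referenced by BackendV3::new.
#[derive(Debug, PartialEq)]
enum Attention {
    Paged,
    FlashDecoding,
    FlashInfer,
}

impl Attention {
    // Same numbers the removed inline `if` used.
    fn block_size(&self) -> u32 {
        match self {
            Attention::FlashDecoding => 256,
            Attention::FlashInfer => 1,
            Attention::Paged => 16,
        }
    }
}

impl FromStr for Attention {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "paged" => Ok(Attention::Paged),
            "flashdecoding" => Ok(Attention::FlashDecoding), // assumed spelling
            "flashinfer" => Ok(Attention::FlashInfer),       // assumed spelling
            other => Err(format!("invalid attention `{other}`")),
        }
    }
}

fn main() {
    // Mirrors the new env-var contract: ATTENTION is expected to be set upstream.
    let attention: Attention = std::env::var("ATTENTION")
        .unwrap_or_else(|_| "paged".to_string())
        .parse()
        .unwrap();
    println!("block_size = {}", attention.block_size());
}
```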
@@ -180,6 +168,8 @@ pub(crate) async fn batching_task(
                 None
             } else {
                 // Minimum batch size
+                // TODO: temporarily disable to avoid incorrect deallocation +
+                // reallocation when using prefix caching.
                 Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
             };

@@ -1,4 +1,4 @@
-use std::{cmp::min, sync::Arc};
+use std::sync::Arc;
 use tokio::sync::{mpsc, oneshot};

 use crate::radix::RadixAllocator;
@@ -137,7 @@ pub trait Allocator {
-
     fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);
 }

 pub struct SimpleAllocator {
     free_blocks: Vec<u32>,
     block_size: u32,
@@ -167,7 +166,7 @@ impl Allocator for SimpleAllocator {
             None => (tokens, 1),
             Some(window_size) => {
                 let repeats = (tokens + window_size - 1) / window_size;
-                let tokens = min(tokens, window_size);
+                let tokens = core::cmp::min(tokens, window_size);
                 (tokens, repeats as usize)
             }
         };
@@ -149,6 +149,7 @@ impl Client {
             requests.push(Request {
                 id: 0,
                 inputs,
+                add_special_tokens: true,
                 input_chunks: Some(Input {
                     chunks: input_chunks,
                 }),
@@ -222,6 +222,7 @@ impl Health for ShardedClient {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
             truncate: 10,
+            add_special_tokens: true,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -252,17 +252,14 @@ impl State {
         let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
         next_batch_span.follows_from(Span::current());

-        let mut batch_requests = Vec::with_capacity(self.entries.len());
-        let mut batch_entries =
-            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
-
+        let mut batch = Vec::with_capacity(self.entries.len());
         let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
         let mut max_blocks = 0;

         // Pop entries starting from the front of the queue
-        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
+        'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
@@ -276,7 +273,7 @@ impl State {
                     // We pad to max input length in the Python shards
                     // We need to take these padding tokens into the equation
                     max_input_length = max_input_length.max(entry.request.input_length);
-                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
+                    prefill_tokens = (batch.len() + 1) as u32 * max_input_length;

                     decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                     let total_tokens = prefill_tokens + decode_tokens + self.speculate;
@@ -290,7 +287,7 @@ impl State {
                     }
                     None
                 }
-                Some(block_allocator) => {
+                Some(_block_allocator) => {
                     prefill_tokens += entry.request.input_length;
                     let max_new_tokens = match self.window_size {
                         None => entry.request.stopping_parameters.max_new_tokens,
@@ -316,26 +313,67 @@ impl State {
                         + self.speculate
                         - 1;

-                    match block_allocator
-                        .allocate(tokens, entry.request.input_ids.clone())
-                        .await
-                    {
-                        None => {
-                            // Entry is over budget
-                            // Add it back to the front
-                            tracing::debug!("Over budget: not enough free blocks");
-                            self.entries.push_front((id, entry));
-                            break 'entry_loop;
-                        }
-                        Some(block_allocation) => {
-                            tracing::debug!("Allocation: {block_allocation:?}");
-                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
-                            Some(block_allocation)
-                        }
-                    }
+                    // If users wants the prefill logprobs, we cannot reuse the cache.
+                    // So no input_ids for the radix tree.
+                    let input_ids = if entry.request.decoder_input_details {
+                        None
+                    } else {
+                        entry.request.input_ids.clone()
+                    };
+
+                    Some((tokens, input_ids))
                 }
             };
+            batch.push((id, entry, block_allocation));
+            if Some(batch.len()) == max_size {
+                break;
+            }
+        }
+
+        // Empty batch
+        if batch.is_empty() {
+            tracing::debug!("Filterered out all entries");
+            return None;
+        }
+
+        // XXX We haven't allocated yet, so we're allowed to ditch the results.
+        // Check if our batch is big enough
+        if let Some(min_size) = min_size {
+            // Batch is too small
+            if batch.len() < min_size {
+                // Add back entries to the queue in the correct order
+                for (id, entry, _) in batch.into_iter().rev() {
+                    self.entries.push_front((id, entry));
+                }
+                return None;
+            }
+        }
+
+        let mut batch_requests = Vec::with_capacity(self.entries.len());
+        let mut batch_entries =
+            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
+
+        for (id, mut entry, block_allocation) in batch {
+            let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
+                (block_allocation, &self.block_allocator)
+            {
+                match block_allocator.allocate(tokens, input_ids).await {
+                    None => {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: not enough free blocks");
+                        self.entries.push_front((id, entry));
+                        break;
+                    }
+                    Some(block_allocation) => {
+                        tracing::debug!("Allocation: {block_allocation:?}");
+                        max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+                        Some(block_allocation)
+                    }
+                }
+            } else {
+                None
+            };
             tracing::debug!("Accepting entry");
             // Create a new span to link the batch back to this entry
             let entry_batch_span = info_span!(parent: &entry.span, "infer");
@@ -378,6 +416,7 @@ impl State {
                 }),
                 inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
+                add_special_tokens: entry.request.add_special_tokens,
                 parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
                 )),
@@ -394,32 +433,6 @@ impl State {
             entry.batch_time = Some(Instant::now());
             // Insert in batch_entries IntMap
             batch_entries.insert(id, entry);
-
-            // Check if max_size
-            if Some(batch_requests.len()) == max_size {
-                break;
-            }
-        }
-
-        // Empty batch
-        if batch_requests.is_empty() {
-            tracing::debug!("Filterered out all entries");
-            return None;
-        }
-
-        // Check if our batch is big enough
-        if let Some(min_size) = min_size {
-            // Batch is too small
-            if batch_requests.len() < min_size {
-                // Add back entries to the queue in the correct order
-                for r in batch_requests.into_iter().rev() {
-                    let id = r.id;
-                    let entry = batch_entries.remove(&id).unwrap();
-                    self.entries.push_front((id, entry));
-                }
-
-                return None;
-            }
-        }
         }

         // Final batch size
@@ -512,6 +525,7 @@ mod tests {
             inputs: vec![],
             input_ids: Some(Arc::new(vec![])),
             input_length: 0,
+            add_special_tokens: true,
             truncate: 0,
             decoder_input_details: false,
             parameters: ValidParameters {
@@ -1,12 +1,10 @@
+use crate::block_allocator::{Allocator, BlockAllocation};
+use slotmap::{DefaultKey, SlotMap};
 use std::{
     collections::{BTreeSet, HashMap},
     sync::Arc,
 };

-use slotmap::{DefaultKey, SlotMap};
-
-use crate::block_allocator::{Allocator, BlockAllocation};
-
 pub struct RadixAllocator {
     allocation_id: u64,

@@ -16,26 +14,26 @@ pub struct RadixAllocator {

     /// Blocks that are immediately available for allocation.
     free_blocks: Vec<u32>,
+
+    #[allow(dead_code)]
+    // This isn't used because the prefix need to match without the windowing
+    // mecanism. This at worst is overallocating, not necessarily being wrong.
+    window_size: Option<u32>,
+
+    block_size: u32,
 }

 impl RadixAllocator {
     pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
-        assert_eq!(
-            block_size, 1,
-            "Radix tree allocator only works with block_size=1, was: {}",
-            block_size
-        );
-        if window_size.is_some() {
-            unimplemented!("Window size not supported in the prefix-caching block allocator yet");
-        }
-
         RadixAllocator {
             allocation_id: 0,
             allocations: HashMap::new(),
-            cache_blocks: RadixTrie::new(),
+            cache_blocks: RadixTrie::new(block_size as usize),

             // Block 0 is reserved for health checks.
             free_blocks: (1..n_blocks).collect(),
+            window_size,
+            block_size,
         }
     }

@@ -63,6 +61,7 @@ impl RadixAllocator {
     }
 }

+// Allocator trait
 impl Allocator for RadixAllocator {
     fn allocate(
         &mut self,
@@ -74,22 +73,25 @@ impl Allocator for RadixAllocator {
             let node_id = self
                 .cache_blocks
                 .find(prefill_tokens.as_slice(), &mut blocks);
-            // Even if this allocation fails below, we need to increase he
-            // refcount to ensure that the prefix that was found is not evicted.

             node_id
         } else {
             self.cache_blocks.root_id()
         };

+        // Even if this allocation fails below, we need to increase he
+        // refcount to ensure that the prefix that was found is not evicted.
         self.cache_blocks
             .incref(prefix_node)
             .expect("Failed to increment refcount");

-        let prefix_len = blocks.len();
+        let prefix_len = blocks.len() * self.block_size as usize;
         let suffix_len = tokens - prefix_len as u32;

-        match self.alloc_or_reclaim(suffix_len as usize) {
+        let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
+
+        tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
+
+        match self.alloc_or_reclaim(suffix_blocks as usize) {
             Some(suffix_blocks) => blocks.extend(suffix_blocks),
             None => {
                 self.cache_blocks
@@ -100,7 +102,20 @@ impl Allocator for RadixAllocator {
         }

         // 1:1 mapping of blocks and slots.
-        let slots = blocks.clone();
+        let slots = if self.block_size == 1 {
+            blocks.clone()
+        } else {
+            let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
+            'slots: for block_id in &blocks {
+                for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
+                    slots.push(s);
+                    if slots.len() as u32 == tokens {
+                        break 'slots;
+                    }
+                }
+            }
+            slots
+        };

         let allocation = RadixAllocation {
             prefix_node,
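A quick worked check of the new slot expansion, under the same numbers as the tests added further below (block_size 2, 12 blocks, block 0 reserved): a request for 8 tokens receives blocks [8, 9, 10, 11], and each block id b expands to slots b*2 and b*2+1, truncated at the token count. The sketch below reproduces just that arithmetic standalone; `expand_slots` is not TGI API, only an illustration:

```rust
// Standalone sketch of the block -> slot expansion shown above; `block_size`
// and `tokens` mirror the allocator fields, the function name is illustrative.
fn expand_slots(blocks: &[u32], block_size: u32, tokens: u32) -> Vec<u32> {
    let mut slots = Vec::with_capacity(blocks.len() * block_size as usize);
    'outer: for block_id in blocks {
        for s in (block_id * block_size)..((block_id + 1) * block_size) {
            slots.push(s);
            if slots.len() as u32 == tokens {
                break 'outer;
            }
        }
    }
    slots
}

fn main() {
    // Matches the `allocator_block_size` test: blocks 8..=11, block_size 2, 8 tokens.
    assert_eq!(
        expand_slots(&[8, 9, 10, 11], 2, 8),
        vec![16, 17, 18, 19, 20, 21, 22, 23]
    );
    // Non-aligned case (7 tokens) drops the last slot,
    // as in `allocator_block_size_non_aligned`.
    assert_eq!(
        expand_slots(&[8, 9, 10, 11], 2, 7),
        vec![16, 17, 18, 19, 20, 21, 22]
    );
}
```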
@@ -108,6 +123,8 @@ impl Allocator for RadixAllocator {
             prefill_tokens: prefill_tokens.clone(),
         };

+        tracing::debug!("Blocks {blocks:?}");
+
         self.allocation_id += 1;
         self.allocations.insert(self.allocation_id, allocation);

@@ -136,27 +153,38 @@ impl Allocator for RadixAllocator {
         // If there are prefill tokens that did not come from the cache,
         // add them to the cache.
         if prefill_tokens.len() > allocation.cached_prefix_len {
-            let prefix_len = self
-                .cache_blocks
-                .insert(prefill_tokens, &blocks[..prefill_tokens.len()])
-                // Unwrap, failing is a programming error.
-                .expect("Failed to store prefill tokens");
-
-            // We can have a prefill with the following structure:
-            //
-            // |---| From the prefix cache.
-            // A B C D E F G
-            //|--------| Found in the trie during insertion.
-            //
-            // This means that while processing this request there was a
-            // partially overlapping request that had A..=E in its
-            // prefill. In this case we need to free the blocks D E.
-            self.free_blocks
-                .extend(&blocks[allocation.cached_prefix_len..prefix_len]);
+            let aligned =
+                (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
+            if aligned > 0 {
+                let prefix_len = self
+                    .cache_blocks
+                    .insert(
+                        &prefill_tokens[..aligned],
+                        &blocks[..aligned / self.block_size as usize],
+                    )
+                    // Unwrap, failing is a programming error.
+                    .expect("Failed to store prefill tokens");
+                // We can have a prefill with the following structure:
+                //
+                // |---| From the prefix cache.
+                // A B C D E F G
+                //|--------| Found in the trie during insertion.
+                //
+                // This means that while processing this request there was a
+                // partially overlapping request that had A..=E in its
+                // prefill. In this case we need to free the blocks D E.
+                if prefix_len > allocation.cached_prefix_len {
+                    self.free_blocks.extend(
+                        &blocks[allocation.cached_prefix_len / self.block_size as usize
+                            ..prefix_len / self.block_size as usize],
+                    );
+                }
+            }
         }

         // Free non-prefill blocks.
-        self.free_blocks.extend(&blocks[prefill_tokens.len()..]);
+        self.free_blocks
+            .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
     } else {
         self.free_blocks.extend(blocks);
     }
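The alignment done in `free` can be sanity-checked by hand: only whole blocks go back into the radix cache, so the prefill length is first rounded down to a multiple of `block_size`, and the block indices past that point return to the free list. A small sketch of that bookkeeping, assuming illustrative names (`cacheable_split` is not part of the allocator):

```rust
// Illustrative only: mirrors the rounding performed in `free` above.
fn cacheable_split(prefill_len: usize, num_blocks: usize, block_size: usize) -> (usize, usize) {
    // Prefill tokens that cover whole blocks and can be inserted into the trie.
    let aligned_tokens = (prefill_len / block_size) * block_size;
    let cached_blocks = aligned_tokens / block_size;
    // Remaining blocks (partial prefill block + decode blocks) are freed directly.
    let freed_blocks = num_blocks - cached_blocks;
    (cached_blocks, freed_blocks)
}

fn main() {
    // 7 prefill tokens over 4 blocks of size 2: 3 blocks are cacheable,
    // the half-used 4th block goes straight back to the free list.
    assert_eq!(cacheable_split(7, 4, 2), (3, 1));
    // With block_size 1 every prefill token maps to a cacheable block.
    assert_eq!(cacheable_split(7, 7, 1), (7, 0));
}
```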
@@ -204,16 +232,14 @@ pub struct RadixTrie {
     /// Time as a monotonically increating counter to avoid the system
     /// call that a real time lookup would require.
     time: u64,
-}
-impl Default for RadixTrie {
-    fn default() -> Self {
-        Self::new()
-    }
+
+    /// All blocks need to be aligned with this
+    block_size: usize,
 }

 impl RadixTrie {
     /// Construct a new radix trie.
-    pub fn new() -> Self {
+    pub fn new(block_size: usize) -> Self {
         let root = TrieNode::new(vec![], vec![], 0, None);
         let mut nodes = SlotMap::new();
         let root = nodes.insert(root);
@@ -222,13 +248,14 @@ impl RadixTrie {
             nodes,
             root,
             time: 0,
+            block_size,
         }
     }

     /// Find the prefix of the given tokens.
     ///
     /// The blocks corresponding to the part of the prefix that could be found
-    /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`.
+    /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.
     /// Returns the identifier of the trie node that contains the longest
     /// prefix. The node identifier can be used by callers to e.g. increase its
     /// reference count.
@@ -246,8 +273,9 @@ impl RadixTrie {
         if let Some(&child_id) = node.children.get(&key[0]) {
             self.update_access_time(child_id);
             let child = self.nodes.get(child_id).expect("Invalid child identifier");
-            let shared_prefix_len = child.key.shared_prefix_len(key);
-            blocks.extend(&child.blocks[..shared_prefix_len]);
+            let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
+            assert_eq!(shared_prefix_len % self.block_size, 0);
+            blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);

             let key = &key[shared_prefix_len..];
             if !key.is_empty() {
@@ -276,6 +304,11 @@ impl RadixTrie {

         node.ref_count -= 1;
         if node.ref_count == 0 {
+            assert!(
+                node.children.is_empty(),
+                "Nodes with children must have refcount > 0"
+            );
+
             self.leaves.insert((node.last_accessed, node_id));
         }

@@ -303,7 +336,7 @@ impl RadixTrie {
     /// Evict `n_blocks` from the trie.
     ///
     /// Returns the evicted blocks. When the length is less than `n_blocks`,
-    /// not enough blocks could beevicted.
+    /// not enough blocks could be evicted.
     pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {
         // NOTE: we don't return Result here. If any of the unwrapping fails,
         // it's a programming error in the trie implementation, not a user
@@ -318,6 +351,12 @@ impl RadixTrie {
             let blocks_needed = n_blocks - evicted.len();

             let node = self.nodes.get(node_id).expect("Leave does not exist");
+            assert_eq!(
+                node.ref_count, 0,
+                "Leaf must have refcount of 0, got {}",
+                node.ref_count
+            );
+
             if blocks_needed >= node.blocks.len() {
                 // We need to evict the whole node if we need more blocks than it has.
                 let node = self.remove_node(node_id);
@@ -348,7 +387,8 @@ impl RadixTrie {
     /// the first 10 elements of the tree **the blocks are not updated**.
     pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
         self.time += 1;
-        self.insert_(self.root, tokens, blocks)
+        let common = self.insert_(self.root, tokens, blocks)?;
+        Ok(common)
     }

     /// Insertion worker.
@@ -362,7 +402,7 @@ impl RadixTrie {
        // the part of the prefix that is already in the trie to detect
        // mismatches.

-        if tokens.len() != blocks.len() {
+        if tokens.len() != blocks.len() * self.block_size {
            return Err(TrieError::BlockTokenCountMismatch);
        }

@@ -373,10 +413,10 @@ impl RadixTrie {
                 .get_mut(child_id)
                 // Unwrap here, since failure is a bug.
                 .expect("Child node does not exist");
-            let shared_prefix_len = child.key.shared_prefix_len(tokens);
+            let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);

             // We are done, the prefix is already in the trie.
-            if shared_prefix_len == tokens.len() {
+            if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
                 return Ok(shared_prefix_len);
             }

|
@ -386,7 +426,7 @@ impl RadixTrie {
|
||||||
+ self.insert_(
|
+ self.insert_(
|
||||||
child_id,
|
child_id,
|
||||||
&tokens[shared_prefix_len..],
|
&tokens[shared_prefix_len..],
|
||||||
&blocks[shared_prefix_len..],
|
&blocks[shared_prefix_len / self.block_size..],
|
||||||
)?);
|
)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -395,7 +435,7 @@ impl RadixTrie {
|
||||||
// remainder of the prefix into the node again
|
// remainder of the prefix into the node again
|
||||||
let child_id = self.split_node(child_id, shared_prefix_len);
|
let child_id = self.split_node(child_id, shared_prefix_len);
|
||||||
let key = &tokens[shared_prefix_len..];
|
let key = &tokens[shared_prefix_len..];
|
||||||
let blocks = &blocks[shared_prefix_len..];
|
let blocks = &blocks[shared_prefix_len / self.block_size..];
|
||||||
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
|
Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
|
||||||
} else {
|
} else {
|
||||||
self.add_node(node_id, tokens, blocks);
|
self.add_node(node_id, tokens, blocks);
|
||||||
|
@ -472,12 +512,16 @@ impl RadixTrie {
|
||||||
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
|
||||||
// Unwrap here, passing in an unknown id is a programming error.
|
// Unwrap here, passing in an unknown id is a programming error.
|
||||||
let node = self.nodes.remove(node_id).expect("Unknown node");
|
let node = self.nodes.remove(node_id).expect("Unknown node");
|
||||||
|
assert!(
|
||||||
|
node.children.is_empty(),
|
||||||
|
"Tried to remove a node with {} children",
|
||||||
|
node.children.len()
|
||||||
|
);
|
||||||
let parent_id = node.parent.expect("Attempted to remove root node");
|
let parent_id = node.parent.expect("Attempted to remove root node");
|
||||||
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
|
||||||
parent.children.remove(&node.key[0]);
|
parent.children.remove(&node.key[0]);
|
||||||
self.decref(parent_id)
|
self.decref(parent_id)
|
||||||
.expect("Failed to decrease parent refcount");
|
.expect("Failed to decrease parent refcount");
|
||||||
self.nodes.remove(node_id);
|
|
||||||
node
|
node
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -549,34 +593,56 @@ impl TrieNode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper trait to get the length of the shared prefix of two sequences.
|
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
|
||||||
trait SharedPrefixLen {
|
let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
|
||||||
fn shared_prefix_len(&self, other: &Self) -> usize;
|
// NOTE: this is the case because the child node was chosen based on
|
||||||
}
|
// matching the first character of the key/prefix.
|
||||||
|
assert!(full > 0, "Prefixes must at least share 1 token");
|
||||||
impl<T> SharedPrefixLen for [T]
|
(full / block_size) * block_size
|
||||||
where
|
|
||||||
T: PartialEq,
|
|
||||||
{
|
|
||||||
fn shared_prefix_len(&self, other: &Self) -> usize {
|
|
||||||
self.iter().zip(other).take_while(|(a, b)| a == b).count()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::block_allocator::Allocator;
|
use super::*;
|
||||||
|
|
||||||
use super::RadixAllocator;
|
#[test]
|
||||||
|
fn allocator_block_size() {
|
||||||
|
let mut cache = RadixAllocator::new(2, 12, None);
|
||||||
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||||
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
|
||||||
|
assert_eq!(allocation.prefix_len, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn allocator_block_size_non_aligned() {
|
||||||
|
let mut cache = RadixAllocator::new(2, 12, None);
|
||||||
|
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||||
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||||
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
|
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
|
||||||
|
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
|
||||||
|
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
|
||||||
|
assert_eq!(allocation.prefix_len, 2);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allocator_reuses_prefixes() {
|
fn allocator_reuses_prefixes() {
|
||||||
let mut cache = RadixAllocator::new(1, 12, None);
|
let mut cache = RadixAllocator::new(1, 12, None);
|
||||||
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
|
||||||
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);
|
||||||
assert_eq!(allocation.slots, allocation.slots);
|
assert_eq!(allocation.blocks, allocation.slots);
|
||||||
assert_eq!(allocation.prefix_len, 0);
|
assert_eq!(allocation.prefix_len, 0);
|
||||||
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
cache.free(allocation.blocks.clone(), allocation.allocation_id);
|
||||||
|
|
||||||
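To make the alignment rule concrete: `shared_prefix` counts the raw token overlap and then rounds it down to a multiple of `block_size`, which is what the block-size trie tests below rely on. The sketch reproduces the same arithmetic standalone (assuming non-empty overlap, as the helper's assert requires), so the rounding is easy to check by hand:

```rust
// Same arithmetic as the `shared_prefix` helper introduced above, reproduced
// standalone for illustration (the assert on non-empty overlap is omitted).
fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {
    let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();
    (full / block_size) * block_size
}

fn main() {
    // 2 matching tokens, block_size 2 -> already aligned.
    assert_eq!(shared_prefix(&[0, 1, 2, 3], &[0, 1, 3, 4], 2), 2);
    // 3 matching tokens, block_size 2 -> rounded down to one full block (2 tokens).
    assert_eq!(shared_prefix(&[0, 1, 2, 3], &[0, 1, 2, 9], 2), 2);
    // 4 matching tokens, block_size 2 -> two full blocks.
    assert_eq!(shared_prefix(&[0, 1, 2, 3], &[0, 1, 2, 3], 2), 4);
}
```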
|
@ -665,7 +731,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn trie_insertions_have_correct_prefix_len() {
|
fn trie_insertions_have_correct_prefix_len() {
|
||||||
let mut trie = super::RadixTrie::new();
|
let mut trie = RadixTrie::new(1);
|
||||||
|
|
||||||
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
|
assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);
|
||||||
|
|
||||||
|
@ -686,9 +752,33 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn trie_insertions_block_size() {
|
||||||
|
let mut trie = RadixTrie::new(2);
|
||||||
|
|
||||||
|
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
|
||||||
|
|
||||||
|
// Already exists.
|
||||||
|
// But needs to be block_size aligned
|
||||||
|
assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
|
||||||
|
|
||||||
|
// Completely new at root-level
|
||||||
|
assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
|
||||||
|
|
||||||
|
// Contains full prefix, but longer.
|
||||||
|
assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
|
||||||
|
|
||||||
|
// Shares partial prefix, we need a split.
|
||||||
|
assert_eq!(
|
||||||
|
trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
|
||||||
|
.unwrap(),
|
||||||
|
2
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn trie_get_returns_correct_blocks() {
|
fn trie_get_returns_correct_blocks() {
|
||||||
let mut trie = super::RadixTrie::new();
|
let mut trie = RadixTrie::new(1);
|
||||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||||
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
|
||||||
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
|
||||||
|
@ -722,7 +812,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn trie_evict_removes_correct_blocks() {
|
fn trie_evict_removes_correct_blocks() {
|
||||||
let mut trie = super::RadixTrie::new();
|
let mut trie = RadixTrie::new(1);
|
||||||
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
|
||||||
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
@ -148,6 +148,7 @@ async fn prefill(
|
||||||
}),
|
}),
|
||||||
inputs: sequence.clone(),
|
inputs: sequence.clone(),
|
||||||
truncate: sequence_length,
|
truncate: sequence_length,
|
||||||
|
add_special_tokens: true,
|
||||||
parameters: Some(parameters.clone()),
|
parameters: Some(parameters.clone()),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: decode_length,
|
max_new_tokens: decode_length,
|
||||||
|
|
|
@ -757,7 +757,12 @@ class AsyncClient:
|
||||||
continue
|
continue
|
||||||
payload = byte_payload.decode("utf-8")
|
payload = byte_payload.decode("utf-8")
|
||||||
if payload.startswith("data:"):
|
if payload.startswith("data:"):
|
||||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
payload_data = (
|
||||||
|
payload.lstrip("data:").rstrip("\n").removeprefix(" ")
|
||||||
|
)
|
||||||
|
if payload_data == "[DONE]":
|
||||||
|
break
|
||||||
|
json_payload = json.loads(payload_data)
|
||||||
try:
|
try:
|
||||||
response = ChatCompletionChunk(**json_payload)
|
response = ChatCompletionChunk(**json_payload)
|
||||||
yield response
|
yield response
|
||||||
|
|
|
@@ -556,6 +556,37 @@
           }
         }
       }
+    },
+    "/v1/models": {
+      "get": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Get model info",
+        "operationId": "openai_get_model_info",
+        "responses": {
+          "200": {
+            "description": "Served model info",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ModelInfo"
+                }
+              }
+            }
+          },
+          "404": {
+            "description": "Model not found",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            }
+          }
+        }
+      }
     }
   },
   "components": {
@@ -924,7 +955,7 @@
         "tool_prompt": {
           "type": "string",
           "description": "A prompt to be appended before the tools",
-          "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
+          "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.",
          "nullable": true
        },
        "tools": {
@@ -1747,6 +1778,35 @@
           }
         ]
       },
+      "ModelInfo": {
+        "type": "object",
+        "required": [
+          "id",
+          "object",
+          "created",
+          "owned_by"
+        ],
+        "properties": {
+          "created": {
+            "type": "integer",
+            "format": "int64",
+            "example": 1686935002,
+            "minimum": 0
+          },
+          "id": {
+            "type": "string",
+            "example": "gpt2"
+          },
+          "object": {
+            "type": "string",
+            "example": "model"
+          },
+          "owned_by": {
+            "type": "string",
+            "example": "openai"
+          }
+        }
+      },
       "OutputMessage": {
         "oneOf": [
           {
@@ -71,6 +71,8 @@
     title: How Guidance Works (via outlines)
   - local: conceptual/lora
     title: LoRA (Low-Rank Adaptation)
+  - local: conceptual/external
+    title: External Resources


   title: Conceptual Guides
@@ -157,7 +157,12 @@ from huggingface_hub import InferenceClient

 client = InferenceClient("http://localhost:3000")

-regexp = "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)"
+section_regex = "(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
+regexp = f"HELLO\.{section_regex}\.WORLD\.{section_regex}"
+
+# This is a more realistic example of an ip address regex
+# regexp = f"{section_regex}\.{section_regex}\.{section_regex}\.{section_regex}"
+

 resp = client.text_generation(
     f"Whats Googles DNS? Please use the following regex: {regexp}",
@@ -170,7 +175,7 @@ resp = client.text_generation(


 print(resp)
-# 7.1.1.1
+# HELLO.255.WORLD.255

 ```

@@ -0,0 +1,4 @@
+# External Resources
+
+- Adyen wrote a detailed article about the interplay between TGI's main components: router and server.
+  [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
@@ -1,5 +1,6 @@
 # Streaming

+
 ## What is Streaming?

 Token streaming is the mode in which the server returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation. Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience.
@@ -12,7 +12,24 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
 docker run --rm --privileged --cap-add=sys_nice \
     --device=/dev/dri \
     --ipc=host --shm-size 1g --net host -v $volume:/data \
-    ghcr.io/huggingface/text-generation-inference:2.2.0-intel \
+    ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
+    --model-id $model --cuda-graphs 0
+```
+
+# Using TGI with Intel CPUs
+
+Intel® Extension for PyTorch (IPEX) also provides further optimizations for Intel CPUs. The IPEX provides optimization operations such as flash attention, page attention, Add + LayerNorm, ROPE and more.
+
+On a server powered by Intel CPU, TGI can be launched with the following command:
+
+```bash
+model=teknium/OpenHermes-2.5-Mistral-7B
+volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
+
+docker run --rm --privileged --cap-add=sys_nice \
+    --device=/dev/dri \
+    --ipc=host --shm-size 1g --net host -v $volume:/data \
+    ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
     --model-id $model --cuda-graphs 0
 ```

flake.lock (154 lines changed)
@@ -492,24 +492,6 @@
         "type": "github"
       }
     },
-    "flake-utils_7": {
-      "inputs": {
-        "systems": "systems_7"
-      },
-      "locked": {
-        "lastModified": 1710146030,
-        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "type": "github"
-      }
-    },
     "gitignore": {
       "inputs": {
         "nixpkgs": [
@@ -594,27 +576,6 @@
         "type": "github"
       }
     },
-    "nix-github-actions": {
-      "inputs": {
-        "nixpkgs": [
-          "poetry2nix",
-          "nixpkgs"
-        ]
-      },
-      "locked": {
-        "lastModified": 1703863825,
-        "narHash": "sha256-rXwqjtwiGKJheXB43ybM8NwWB8rO2dSRrEqes0S7F5Y=",
-        "owner": "nix-community",
-        "repo": "nix-github-actions",
-        "rev": "5163432afc817cf8bd1f031418d1869e4c9d5547",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-community",
-        "repo": "nix-github-actions",
-        "type": "github"
-      }
-    },
     "nix-test-runner": {
       "flake": false,
       "locked": {
@@ -739,58 +700,20 @@
     },
     "nixpkgs_6": {
       "locked": {
-        "lastModified": 1719763542,
-        "narHash": "sha256-mXkOj9sJ0f69Nkc2dGGOWtof9d1YNY8Le/Hia3RN+8Q=",
-        "owner": "NixOS",
+        "lastModified": 1723912943,
+        "narHash": "sha256-39F9GzyhxYcY3wTeKuEFWRJWcrGBosO4nf4xzMTWZX8=",
+        "owner": "danieldk",
         "repo": "nixpkgs",
-        "rev": "e6cdd8a11b26b4d60593733106042141756b54a3",
+        "rev": "b82cdca86dbb30013b76c4b55d48806476820a5c",
         "type": "github"
       },
       "original": {
-        "owner": "NixOS",
-        "ref": "nixos-unstable-small",
+        "owner": "danieldk",
+        "ref": "cuda-12.4",
         "repo": "nixpkgs",
         "type": "github"
       }
     },
-    "nixpkgs_7": {
-      "locked": {
-        "lastModified": 1723418128,
-        "narHash": "sha256-k1pEqsnB6ikZyasXbtV6A9akPZMKlsyENPDUA6PXoJo=",
-        "owner": "nixos",
-        "repo": "nixpkgs",
-        "rev": "129f579cbb5b4c1ad258fd96bdfb78eb14802727",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nixos",
-        "ref": "nixos-unstable-small",
-        "repo": "nixpkgs",
-        "type": "github"
-      }
-    },
-    "poetry2nix": {
-      "inputs": {
-        "flake-utils": "flake-utils_7",
-        "nix-github-actions": "nix-github-actions",
-        "nixpkgs": "nixpkgs_6",
-        "systems": "systems_8",
-        "treefmt-nix": "treefmt-nix"
-      },
-      "locked": {
-        "lastModified": 1723512448,
-        "narHash": "sha256-VSTtxGKre1p6zd6ACuBmfDcR+BT9+ml8Y3KrSbfGFYU=",
-        "owner": "nix-community",
-        "repo": "poetry2nix",
-        "rev": "ed52f844c4dd04dde45550c3189529854384124e",
-        "type": "github"
-      },
-      "original": {
-        "owner": "nix-community",
-        "repo": "poetry2nix",
-        "type": "github"
-      }
-    },
     "pre-commit-hooks": {
       "inputs": {
         "flake-compat": [
@@ -900,7 +823,6 @@
         "tgi-nix",
         "nixpkgs"
       ],
-      "poetry2nix": "poetry2nix",
       "rust-overlay": "rust-overlay",
       "tgi-nix": "tgi-nix"
     }
@@ -913,11 +835,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1723515680,
-      "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=",
+      "lastModified": 1724638882,
+      "narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3",
+      "rev": "19b70f147b9c67a759e35824b241f1ed92e46694",
       "type": "github"
     },
     "original": {
@ -1016,46 +938,17 @@
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"systems_7": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1681028828,
|
|
||||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"systems_8": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1681028828,
|
|
||||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
|
||||||
"owner": "nix-systems",
|
|
||||||
"repo": "default",
|
|
||||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"id": "systems",
|
|
||||||
"type": "indirect"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"tgi-nix": {
|
"tgi-nix": {
|
||||||
"inputs": {
|
"inputs": {
|
||||||
"flake-compat": "flake-compat_4",
|
"flake-compat": "flake-compat_4",
|
||||||
"nixpkgs": "nixpkgs_7"
|
"nixpkgs": "nixpkgs_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1723532088,
|
"lastModified": 1725011596,
|
||||||
"narHash": "sha256-6h/P/BkFDw8txlikonKXp5IbluHSPhHJTQRftJLkbLQ=",
|
"narHash": "sha256-zfq8lOXFgJnKxxsqSelHuKUvhxgH3cEmLoAgsOO62Cg=",
|
||||||
"owner": "danieldk",
|
"owner": "danieldk",
|
||||||
"repo": "tgi-nix",
|
"repo": "tgi-nix",
|
||||||
"rev": "32335a37ce0f703bab901baf7b74eb11e9972d5f",
|
"rev": "717c2b07e38538abf05237cca65b2d1363c2c9af",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
@ -1063,27 +956,6 @@
|
||||||
"repo": "tgi-nix",
|
"repo": "tgi-nix",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"treefmt-nix": {
|
|
||||||
"inputs": {
|
|
||||||
"nixpkgs": [
|
|
||||||
"poetry2nix",
|
|
||||||
"nixpkgs"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1719749022,
|
|
||||||
"narHash": "sha256-ddPKHcqaKCIFSFc/cvxS14goUhCOAwsM1PbMr0ZtHMg=",
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "treefmt-nix",
|
|
||||||
"rev": "8df5ff62195d4e67e2264df0b7f5e8c9995fd0bd",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"owner": "numtide",
|
|
||||||
"repo": "treefmt-nix",
|
|
||||||
"type": "github"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"root": "root",
|
"root": "root",
|
||||||
|
|

96  flake.nix
@ -8,7 +8,6 @@
     tgi-nix.url = "github:danieldk/tgi-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
-    poetry2nix.url = "github:nix-community/poetry2nix";
     rust-overlay = {
       url = "github:oxalica/rust-overlay";
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
@ -23,7 +22,6 @@
       flake-utils,
       rust-overlay,
       tgi-nix,
-      poetry2nix,
     }:
     flake-utils.lib.eachDefaultSystem (
       system:
@ -33,25 +31,40 @@
          src = ./.;
          additionalCargoNixArgs = [ "--all-features" ];
        };
-        config = {
-          allowUnfree = true;
-          cudaSupport = true;
-        };
        pkgs = import nixpkgs {
-          inherit config system;
+          inherit system;
+          inherit (tgi-nix.lib) config;
          overlays = [
            rust-overlay.overlays.default
-            tgi-nix.overlay
+            tgi-nix.overlays.default
          ];
        };
-        inherit (poetry2nix.lib.mkPoetry2Nix { inherit pkgs; }) mkPoetryEditablePackage;
-        text-generation-server = mkPoetryEditablePackage { editablePackageSources = ./server; };
        crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };
+        benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override {
+          inherit crateOverrides;
+        };
+        launcher = cargoNix.workspaceMembers.text-generation-launcher.build.override {
+          inherit crateOverrides;
+        };
+        router = cargoNix.workspaceMembers.text-generation-router-v3.build.override {
+          inherit crateOverrides;
+        };
+        server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
      in
      {
-        devShells.default =
-          with pkgs;
-          mkShell {
+        devShells = with pkgs; rec {
+          default = pure;
+
+          pure = mkShell {
+            buildInputs = [
+              benchmark
+              launcher
+              router
+              server
+            ];
+          };
+
+          impure = mkShell {
            buildInputs =
              [
                openssl.dev
@ -62,51 +75,46 @@
                    "rust-src"
                  ];
                })
+                protobuf
              ]
              ++ (with python3.pkgs; [
                venvShellHook
+                docker
                pip
+                ipdb
-                causal-conv1d
+                pyright
-                click
+                pytest
-                einops
+                pytest-asyncio
-                exllamav2
+                ruff
-                fbgemm-gpu
+                syrupy
-                flashinfer
-                flash-attn
-                flash-attn-layer-norm
-                flash-attn-rotary
-                grpc-interceptor
-                grpcio-reflection
-                grpcio-status
-                grpcio-tools
-                hf-transfer
-                loguru
-                mamba-ssm
-                marlin-kernels
-                opentelemetry-api
-                opentelemetry-exporter-otlp
-                opentelemetry-instrumentation-grpc
-                opentelemetry-semantic-conventions
-                peft
-                tokenizers
-                torch
-                transformers
-                vllm
-
-                (cargoNix.workspaceMembers.text-generation-launcher.build.override { inherit crateOverrides; })
-                (cargoNix.workspaceMembers.text-generation-router-v3.build.override { inherit crateOverrides; })
              ]);

+            inputsFrom = [ server ];
+
            venvDir = "./.venv";

-            postVenv = ''
+            postVenvCreation = ''
              unset SOURCE_DATE_EPOCH
+              ( cd server ; python -m pip install --no-dependencies -e . )
+              ( cd clients/python ; python -m pip install --no-dependencies -e . )
            '';
            postShellHook = ''
              unset SOURCE_DATE_EPOCH
+              export PATH=$PATH:~/.cargo/bin
            '';
          };
+        };
+
+        packages.default = pkgs.writeShellApplication {
+          name = "text-generation-inference";
+          runtimeInputs = [
+            server
+            router
+          ];
+          text = ''
+            ${launcher}/bin/text-generation-launcher "$@"
+          '';
+        };
      }
    );
 }
@ -64,6 +64,7 @@ class ResponseComparator(JSONSnapshotExtension):
         self,
         data,
         *,
+        include=None,
         exclude=None,
         matcher=None,
     ):
@ -79,7 +80,12 @@ class ResponseComparator(JSONSnapshotExtension):
             data = [d.model_dump() for d in data]

         data = self._filter(
-            data=data, depth=0, path=(), exclude=exclude, matcher=matcher
+            data=data,
+            depth=0,
+            path=(),
+            exclude=exclude,
+            include=include,
+            matcher=matcher,
         )
         return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"

@ -257,7 +263,7 @@ class IgnoreLogProbResponseComparator(ResponseComparator):

 class LauncherHandle:
     def __init__(self, port: int):
-        self.client = AsyncClient(f"http://localhost:{port}")
+        self.client = AsyncClient(f"http://localhost:{port}", timeout=30)

     def _inner_health(self):
         raise NotImplementedError
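The conftest hunks above thread a new `include` filter through the syrupy `_filter` call and give the `LauncherHandle` client a 30-second timeout. A minimal sketch of the kind of integration test these fixtures serve, assuming a `flash_llama` fixture that exposes that `AsyncClient` and a `response_snapshot` fixture built on the comparator above (both names are illustrative, not part of this diff):

import pytest


@pytest.mark.asyncio
async def test_flash_llama_sketch(flash_llama, response_snapshot):
    # Assumed fixture: wraps the text_generation AsyncClient created above
    # (now constructed with timeout=30 so slower startups do not time out).
    response = await flash_llama.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
    # The ResponseComparator serializes the response to JSON, applying the
    # include/exclude/matcher filters, and diffs it against the stored snapshot.
    assert response == response_snapshot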
@ -5,7 +5,7 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas",
+        "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C",
         "name": null,
         "role": "assistant",
         "tool_calls": null
@ -13,14 +13,14 @@
       "usage": null
     }
   ],
-  "created": 1716553098,
+  "created": 1724792495,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-  "object": "text_completion",
+  "object": "chat.completion",
-  "system_fingerprint": "2.0.5-dev0-native",
+  "system_fingerprint": "2.2.1-dev0-native",
   "usage": {
     "completion_tokens": 100,
-    "prompt_tokens": 62,
+    "prompt_tokens": 61,
-    "total_tokens": 162
+    "total_tokens": 161
   }
 }
@ -8,11 +8,11 @@
        "text": "\n"
      }
    ],
-    "created": 1713284431,
+    "created": 1724833943,
    "id": "",
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "object": "text_completion",
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
  },
  {
    "choices": [
@ -23,11 +23,11 @@
        "text": "\n"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -38,11 +38,11 @@
        "text": "\n"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -53,11 +53,11 @@
        "text": "hd"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -68,11 +68,11 @@
        "text": "\n"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -83,11 +83,11 @@
        "text": "\n"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -98,11 +98,11 @@
        "text": "\n"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -113,11 +113,11 @@
        "text": "aho"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -128,11 +128,11 @@
        "text": "2"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -143,11 +143,11 @@
        "text": "2"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -158,11 +158,11 @@
        "text": "2"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -173,11 +173,11 @@
        "text": "ima"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -188,11 +188,11 @@
        "text": "."
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -203,11 +203,11 @@
        "text": "."
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -218,11 +218,11 @@
        "text": "."
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -233,11 +233,11 @@
        "text": "\n"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -248,11 +248,11 @@
        "text": " Sarah"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -263,11 +263,11 @@
        "text": " Yes"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -278,11 +278,11 @@
        "text": " And"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -293,11 +293,11 @@
        "text": "i"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -308,11 +308,11 @@
        "text": "'"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -323,11 +323,11 @@
        "text": ","
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -338,11 +338,11 @@
        "text": " what"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -353,11 +353,11 @@
        "text": "'"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -368,11 +368,11 @@
        "text": "s"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -383,11 +383,11 @@
        "text": " Moh"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -398,11 +398,11 @@
        "text": " is"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -413,11 +413,11 @@
        "text": "m"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -428,11 +428,11 @@
        "text": " Room"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -443,11 +443,11 @@
        "text": "s"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -458,11 +458,11 @@
        "text": " the"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -473,11 +473,11 @@
        "text": " tired"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -488,11 +488,11 @@
        "text": ":"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -503,11 +503,11 @@
        "text": "'"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -518,11 +518,11 @@
        "text": " capital"
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
@ -530,73 +530,73 @@
        "finish_reason": "",
        "index": 3,
        "logprobs": null,
-        "text": " of"
+        "text": ","
      }
    ],
-    "created": 1713284431,
+    "created": 1724833943,
    "id": "",
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "object": "text_completion",
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
  },
  {
    "choices": [
      {
-        "finish_reason": "",
+        "finish_reason": "length",
        "index": 0,
        "logprobs": null,
        "text": " She"
      }
    ],
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
  },
  {
    "choices": [
      {
-        "finish_reason": "",
+        "finish_reason": "length",
        "index": 1,
        "logprobs": null,
        "text": " scale"
      }
    ],
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
  },
  {
    "choices": [
      {
-        "finish_reason": "",
+        "finish_reason": "length",
        "index": 2,
        "logprobs": null,
        "text": " of"
      }
    ],
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
  },
  {
    "choices": [
      {
-        "finish_reason": "",
+        "finish_reason": "length",
        "index": 3,
        "logprobs": null,
-        "text": " being"
+        "text": " its"
      }
    ],
-    "created": 1713284431,
+    "created": 1724833943,
-    "system_fingerprint": "2.0.1-native"
+    "system_fingerprint": "2.2.1-dev0-native"
  }
]
@ -16,7 +16,7 @@
    },
    {
      "id": 3102,
-      "logprob": -11.1875,
+      "logprob": -11.25,
      "text": " request"
    }
  ],
@ -24,66 +24,66 @@
  "tokens": [
    {
      "id": 185,
-      "logprob": -1.5546875,
+      "logprob": -1.546875,
      "special": false,
      "text": "\n"
    },
    {
      "id": 549,
-      "logprob": -2.84375,
+      "logprob": -2.859375,
      "special": false,
      "text": "The"
    },
    {
      "id": 1727,
-      "logprob": -2.34375,
+      "logprob": -2.484375,
      "special": false,
      "text": " test"
    },
    {
      "id": 3102,
-      "logprob": -0.8359375,
+      "logprob": -0.83203125,
      "special": false,
      "text": " request"
    },
    {
      "id": 317,
-      "logprob": -1.0859375,
+      "logprob": -1.1484375,
      "special": false,
      "text": " is"
    },
    {
-      "id": 254,
-      "logprob": -1.5390625,
+      "id": 245,
+      "logprob": -1.578125,
      "special": false,
-      "text": " the"
+      "text": " a"
    },
    {
-      "id": 1022,
-      "logprob": -1.1875,
+      "id": 3412,
+      "logprob": -2.578125,
      "special": false,
-      "text": " first"
+      "text": " document"
    },
    {
-      "id": 3458,
-      "logprob": -0.35546875,
+      "id": 344,
+      "logprob": -1.125,
      "special": false,
-      "text": " step"
+      "text": " that"
    },
    {
-      "id": 279,
-      "logprob": -0.8828125,
+      "id": 317,
+      "logprob": -1.6953125,
      "special": false,
-      "text": " in"
+      "text": " is"
    },
    {
-      "id": 254,
-      "logprob": -0.71484375,
+      "id": 1222,
+      "logprob": -1.71875,
      "special": false,
-      "text": " the"
+      "text": " used"
    }
  ],
  "top_tokens": null
 },
-"generated_text": "\nThe test request is the first step in the"
+"generated_text": "\nThe test request is a document that is used"
 }
@ -37,56 +37,56 @@
      {
        "id": 1727,
-        "logprob": -2.359375,
+        "logprob": -2.4375,
        "text": " test"
      },
      {
        "id": 3102,
-        "logprob": -0.83203125,
+        "logprob": -0.83984375,
        "text": " request"
      },
      {
        "id": 317,
-        "logprob": -1.125,
+        "logprob": -1.1328125,
        "text": " is"
      },
      {
-        "id": 245,
-        "logprob": -1.5703125,
-        "text": " a"
+        "id": 254,
+        "logprob": -1.515625,
+        "text": " the"
      },
      {
-        "id": 3412,
-        "logprob": -2.578125,
-        "text": " document"
+        "id": 1022,
+        "logprob": -1.15625,
+        "text": " first"
      },
      {
-        "id": 344,
-        "logprob": -1.125,
-        "text": " that"
+        "id": 3458,
+        "logprob": -0.3671875,
+        "text": " step"
      },
      {
-        "id": 317,
-        "logprob": -1.6953125,
-        "text": " is"
+        "id": 279,
+        "logprob": -0.88671875,
+        "text": " in"
      },
      {
-        "id": 1222,
-        "logprob": -1.75,
-        "text": " used"
+        "id": 254,
+        "logprob": -0.69140625,
+        "text": " the"
      }
    ],
    "top_tokens": null
  },
-  "generated_text": "\nThe test request is a document that is used"
+  "generated_text": "\nThe test request is the first step in the"
 },
 {
   "details": {
@ -126,56 +126,56 @@
      {
        "id": 1727,
-        "logprob": -2.359375,
+        "logprob": -2.4375,
        "text": " test"
      },
      {
        "id": 3102,
-        "logprob": -0.83203125,
+        "logprob": -0.83984375,
        "text": " request"
      },
      {
        "id": 317,
-        "logprob": -1.125,
+        "logprob": -1.1328125,
        "text": " is"
      },
      {
-        "id": 245,
-        "logprob": -1.5703125,
-        "text": " a"
+        "id": 254,
+        "logprob": -1.515625,
+        "text": " the"
      },
      {
-        "id": 3412,
-        "logprob": -2.578125,
-        "text": " document"
+        "id": 1022,
+        "logprob": -1.15625,
+        "text": " first"
      },
      {
-        "id": 344,
-        "logprob": -1.125,
-        "text": " that"
+        "id": 3458,
+        "logprob": -0.3671875,
+        "text": " step"
      },
      {
-        "id": 317,
-        "logprob": -1.6953125,
-        "text": " is"
+        "id": 279,
+        "logprob": -0.88671875,
+        "text": " in"
      },
      {
-        "id": 1222,
-        "logprob": -1.75,
-        "text": " used"
+        "id": 254,
+        "logprob": -0.69140625,
+        "text": " the"
      }
    ],
    "top_tokens": null
  },
-  "generated_text": "\nThe test request is a document that is used"
+  "generated_text": "\nThe test request is the first step in the"
 },
 {
   "details": {
@ -215,56 +215,56 @@
      {
        "id": 1727,
-        "logprob": -2.359375,
+        "logprob": -2.4375,
        "text": " test"
      },
      {
        "id": 3102,
-        "logprob": -0.83203125,
+        "logprob": -0.83984375,
        "text": " request"
      },
      {
        "id": 317,
-        "logprob": -1.125,
+        "logprob": -1.1328125,
        "text": " is"
      },
      {
-        "id": 245,
-        "logprob": -1.5703125,
-        "text": " a"
+        "id": 254,
+        "logprob": -1.515625,
+        "text": " the"
      },
      {
-        "id": 3412,
-        "logprob": -2.578125,
-        "text": " document"
+        "id": 1022,
+        "logprob": -1.15625,
+        "text": " first"
      },
      {
-        "id": 344,
-        "logprob": -1.125,
-        "text": " that"
+        "id": 3458,
+        "logprob": -0.3671875,
+        "text": " step"
      },
      {
-        "id": 317,
-        "logprob": -1.6953125,
-        "text": " is"
+        "id": 279,
+        "logprob": -0.88671875,
+        "text": " in"
      },
      {
-        "id": 1222,
-        "logprob": -1.75,
-        "text": " used"
+        "id": 254,
+        "logprob": -0.69140625,
+        "text": " the"
      }
    ],
    "top_tokens": null
  },
-  "generated_text": "\nThe test request is a document that is used"
+  "generated_text": "\nThe test request is the first step in the"
 },
 {
   "details": {
@ -304,55 +304,55 @@
      {
        "id": 1727,
-        "logprob": -2.359375,
+        "logprob": -2.4375,
        "text": " test"
      },
      {
        "id": 3102,
-        "logprob": -0.83203125,
+        "logprob": -0.83984375,
        "text": " request"
      },
      {
        "id": 317,
-        "logprob": -1.125,
+        "logprob": -1.1328125,
        "text": " is"
      },
      {
-        "id": 245,
-        "logprob": -1.5703125,
-        "text": " a"
+        "id": 254,
+        "logprob": -1.515625,
+        "text": " the"
      },
      {
-        "id": 3412,
-        "logprob": -2.578125,
-        "text": " document"
+        "id": 1022,
+        "logprob": -1.15625,
+        "text": " first"
      },
      {
-        "id": 344,
-        "logprob": -1.125,
-        "text": " that"
+        "id": 3458,
+        "logprob": -0.3671875,
+        "text": " step"
      },
      {
-        "id": 317,
-        "logprob": -1.6953125,
-        "text": " is"
+        "id": 279,
+        "logprob": -0.88671875,
+        "text": " in"
      },
      {
-        "id": 1222,
-        "logprob": -1.75,
-        "text": " used"
+        "id": 254,
+        "logprob": -0.69140625,
+        "text": " the"
      }
    ],
    "top_tokens": null
  },
-  "generated_text": "\nThe test request is a document that is used"
+  "generated_text": "\nThe test request is the first step in the"
 }
]
@ -1,8 +1,8 @@
 {
   "details": {
     "best_of_sequences": null,
-    "finish_reason": "length",
-    "generated_tokens": 10,
+    "finish_reason": "stop_sequence",
+    "generated_tokens": 5,
     "prefill": [
       {
         "id": 128000,
@ -16,7 +16,7 @@
      },
      {
        "id": 1715,
-        "logprob": -10.375,
+        "logprob": -10.4375,
        "text": " request"
      }
    ],
@ -29,61 +29,31 @@
        "text": ":"
      },
      {
-        "id": 2209,
-        "logprob": -2.78125,
+        "id": 923,
+        "logprob": -2.84375,
        "special": false,
-        "text": " Is"
+        "text": " add"
      },
      {
-        "id": 279,
-        "logprob": -0.6328125,
+        "id": 264,
+        "logprob": 0.0,
        "special": false,
-        "text": " the"
-      },
-      {
-        "id": 734,
-        "logprob": -2.703125,
-        "special": false,
-        "text": " function"
+        "text": " a"
      },
      {
        "id": 330,
-        "logprob": -0.34179688,
+        "logprob": -0.31640625,
        "special": false,
        "text": " \""
      },
      {
-        "id": 4110,
-        "logprob": -2.359375,
+        "id": 1985,
+        "logprob": 0.0,
        "special": false,
-        "text": "Create"
-      },
-      {
-        "id": 7575,
-        "logprob": -2.1875,
-        "special": false,
-        "text": "Process"
-      },
-      {
-        "id": 1,
-        "logprob": -0.07910156,
-        "special": false,
-        "text": "\""
-      },
-      {
-        "id": 304,
-        "logprob": -0.83203125,
-        "special": false,
-        "text": " in"
-      },
-      {
-        "id": 12468,
-        "logprob": -1.8203125,
-        "special": false,
-        "text": " Win"
+        "text": "test"
      }
    ],
    "top_tokens": null
  },
-  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
+  "generated_text": "Test request: add a \"test"
 }
@ -16,7 +16,7 @@
    },
    {
      "id": 100,
-      "logprob": -0.38549805,
+      "logprob": -0.38305664,
      "text": "_"
    },
    {
@ -29,7 +29,7 @@
  "tokens": [
    {
      "id": 2284,
-      "logprob": -0.31323242,
+      "logprob": -0.296875,
      "special": false,
      "text": "():"
    },
@ -59,19 +59,19 @@
    },
    {
      "id": 10914,
-      "logprob": -0.7817383,
+      "logprob": -0.7734375,
      "special": false,
      "text": " World"
    },
    {
      "id": 16013,
-      "logprob": -0.6328125,
+      "logprob": -0.61816406,
      "special": false,
      "text": "!\")"
    },
    {
      "id": 222,
-      "logprob": -0.0619812,
+      "logprob": -0.054870605,
      "special": false,
      "text": "\n"
    },
@ -83,7 +83,7 @@
    },
    {
      "id": 610,
-      "logprob": -0.4086914,
+      "logprob": -0.4152832,
      "special": false,
      "text": "def"
    },
@ -113,7 +113,7 @@
    },
    {
      "id": 444,
-      "logprob": -0.21826172,
+      "logprob": -0.21618652,
      "special": false,
      "text": "name"
    },
@ -173,7 +173,7 @@
    },
    {
      "id": 11571,
-      "logprob": -0.10021973,
+      "logprob": -0.08892822,
      "special": false,
      "text": "!\""
    },
@ -30,19 +30,19 @@
    },
    {
      "id": 264,
-      "logprob": -0.37573242,
+      "logprob": -0.38061523,
      "special": false,
      "text": " a"
    },
    {
      "id": 633,
-      "logprob": -0.09161377,
+      "logprob": -0.09301758,
      "special": false,
      "text": " new"
    },
    {
      "id": 4480,
-      "logprob": -0.26171875,
+      "logprob": -0.26782227,
      "special": false,
      "text": " feature"
    },
@ -78,7 +78,7 @@
    },
    {
      "id": 13,
-      "logprob": 0.0,
+      "logprob": -0.10632324,
      "special": false,
      "text": "\n"
    }
@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
     print(repr(response.choices[0].message.content))
     assert (
         response.choices[0].message.content
-        == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas"
+        == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C"
     )
     assert response == response_snapshot
@ -36,6 +36,7 @@ tools = [
                     },
                 },
                 "required": ["location", "format"],
+                "additionalProperties": False,
             },
         },
     },
@ -62,13 +63,13 @@ tools = [
                     },
                 },
                 "required": ["location", "format", "num_days"],
+                "additionalProperties": False,
             },
         },
     },
 ]


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -76,7 +77,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
         max_tokens=100,
         seed=1,
         tools=tools,
-        presence_penalty=-1.1,
+        temperature=0.0,
         messages=[
             {
                 "role": "system",
@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_auto(
@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="auto",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@ -129,12 +129,12 @@ async def test_flash_llama_grammar_tools_auto(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
@ -142,7 +142,6 @@ async def test_flash_llama_grammar_tools_auto(
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_choice(
@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@ -168,12 +167,12 @@ async def test_flash_llama_grammar_tools_choice(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
@ -181,7 +180,6 @@ async def test_flash_llama_grammar_tools_choice(
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_stream(
@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@ -210,11 +208,10 @@ async def test_flash_llama_grammar_tools_stream(
     async for response in responses:
         count += 1

-    assert count == 38
+    assert count == 48
     assert response == response_snapshot


-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_insufficient_information(
@ -222,13 +219,13 @@ async def test_flash_llama_grammar_tools_insufficient_information(
 ):
     responses = await flash_llama_grammar_tools.chat(
         max_tokens=100,
-        seed=8,
+        seed=24,
         tools=tools,
         tool_choice="auto",
         messages=[
             {
                 "role": "system",
-                "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
             },
             {
                 "role": "user",
@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
     )

     assert responses.choices[0].message.content is None
-    assert responses.choices[0].message.tool_calls == [
-        {
-            "function": {
-                "arguments": {
-                    "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-                },
-                "description": None,
-                "name": "notify_error",
-            },
-            "id": 0,
-            "type": "function",
-        }
-    ]
+    assert (
+        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
+    )
     assert responses == response_snapshot
@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "aiohttp"
@ -268,16 +268,6 @@ files = [
     {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
 ]

-[[package]]
-name = "colored"
-version = "1.4.4"
-description = "Simple library for color and formatting to terminal"
-optional = false
-python-versions = "*"
-files = [
-    {file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"},
-]
-
 [[package]]
 name = "docker"
 version = "6.1.3"
@ -855,18 +845,17 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]

 [[package]]
 name = "syrupy"
-version = "4.0.1"
+version = "4.7.1"
 description = "Pytest Snapshot Test Utility"
 optional = false
-python-versions = ">=3.8.1,<4"
+python-versions = ">=3.8.1"
 files = [
-    {file = "syrupy-4.0.1-py3-none-any.whl", hash = "sha256:53d3107cc5e18a5def189c721879cea2cdafdee34b879f602133ca08837d0e4b"},
-    {file = "syrupy-4.0.1.tar.gz", hash = "sha256:60e3e94782444e0f978cd3b207de32f6da3199b15a2db32eab02f83cebb63ae8"},
+    {file = "syrupy-4.7.1-py3-none-any.whl", hash = "sha256:be002267a512a4bedddfae2e026c93df1ea928ae10baadc09640516923376d41"},
+    {file = "syrupy-4.7.1.tar.gz", hash = "sha256:f9d4485f3f27d0e5df6ed299cac6fa32eb40a441915d988e82be5a4bdda335c8"},
 ]

 [package.dependencies]
-colored = ">=1.3.92,<2.0.0"
-pytest = ">=7.0.0,<8.0.0"
+pytest = ">=7.0.0,<9.0.0"

 [[package]]
 name = "text-generation"
@ -1049,4 +1038,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "421fbce065cb1499c666599cf0fd83a5ce8fb3bed09e83c16c3a3d6953b34026"
+content-hash = "f5c65e704b02250d73055cd04efcc22f8fc36144eddfc447a71c3a061748db80"
@ -7,7 +7,7 @@ authors = ["Nicolas Patry <nicolas@huggingface.co>"]
 [tool.poetry.dependencies]
 pydantic = "> 2, < 3"
 python = ">=3.9,<3.13"
-syrupy = "4.0.1"
+syrupy = "^4.7.1"
 text-generation = "^0.6.0"
 pytest = "^7.4.0"
 pytest-asyncio = "^0.21.1"
@ -6,7 +6,6 @@ attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
-colored==1.4.4 ; python_version >= "3.9" and python_version < "3.13"
 docker==6.1.3 ; python_version >= "3.9" and python_version < "3.13"
 exceptiongroup==1.1.3 ; python_version >= "3.9" and python_version < "3.11"
 filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
@ -25,7 +24,7 @@ pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
 pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-syrupy==4.0.1 ; python_version >= "3.9" and python_version < "3.13"
+syrupy==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
 text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
@ -8,7 +8,7 @@ use nix::unistd::Pid;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::ffi::OsString;
|
use std::ffi::OsString;
|
||||||
use std::io::{BufRead, BufReader, Lines};
|
use std::io::{BufRead, BufReader};
|
||||||
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
use std::os::unix::process::{CommandExt, ExitStatusExt};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::process::{Child, Command, ExitStatus, Stdio};
|
use std::process::{Child, Command, ExitStatus, Stdio};
|
||||||
|
@ -18,12 +18,103 @@ use std::sync::{mpsc, Arc};
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::thread::sleep;
|
use std::thread::sleep;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use std::{fs, io};
|
use std::{
|
||||||
|
fs, io,
|
||||||
|
io::{Read, Write},
|
||||||
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
|
||||||
|
|
||||||
mod env_runtime;
|
mod env_runtime;
|
||||||
|
|
||||||
|
fn get_config(
|
||||||
|
model_id: &str,
|
||||||
|
revision: &Option<String>,
|
||||||
|
) -> Result<Config, Box<dyn std::error::Error>> {
|
||||||
|
let mut path = std::path::Path::new(model_id).to_path_buf();
|
||||||
|
let model_id = model_id.to_string();
|
||||||
|
let filename = if !path.exists() {
|
||||||
|
// Assume it's a hub id
|
||||||
|
|
||||||
|
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
||||||
|
// env variable has precedence over on file token.
|
||||||
|
ApiBuilder::new().with_token(Some(token)).build()?
|
||||||
|
} else {
|
||||||
|
Api::new()?
|
||||||
|
};
|
||||||
|
let repo = if let Some(ref revision) = revision {
|
||||||
|
api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
revision.to_string(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
api.model(model_id)
|
||||||
|
};
|
||||||
|
repo.get("config.json")?
|
||||||
|
} else {
|
||||||
|
path.push("config.json");
|
||||||
|
path
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = std::fs::read_to_string(filename)?;
|
||||||
|
let config: RawConfig = serde_json::from_str(&content)?;
|
||||||
|
|
||||||
|
let config: Config = config.into();
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
|
||||||
|
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
|
||||||
|
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
|
||||||
|
if let Some(config) = config {
|
||||||
|
if prefix_caching.is_none() {
|
||||||
|
if config.vision_config.is_some() {
|
||||||
|
tracing::info!("Disabling prefix caching because of VLM model");
|
||||||
|
prefix_caching = Some("0".to_string());
|
||||||
|
} else if config.is_encoder_decoder {
|
||||||
|
tracing::info!("Disabling prefix caching because of seq2seq model");
|
||||||
|
prefix_caching = Some("0".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match config.head_dim {
|
||||||
|
Some(h) if h == 64 || h == 128 || h == 256 => {
|
||||||
|
if lora_adapters.is_some() && prefix_caching.is_none() {
|
||||||
|
tracing::info!("Disabling prefix caching because of lora adapters");
|
||||||
|
prefix_caching = Some("0".to_string());
|
||||||
|
}
|
||||||
|
match config.model_type.as_deref() {
|
||||||
|
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
|
||||||
|
// Required because gemma2 needs bfloat16 which is not supported by
|
||||||
|
// flashinfer ?
|
||||||
|
if attention.is_none() {
|
||||||
|
tracing::info!(
|
||||||
|
"Forcing flash decoding because model {} requires it",
|
||||||
|
config.model_type.as_ref().unwrap()
|
||||||
|
);
|
||||||
|
attention = Some("flashdecoding".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some("t5") => {}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
if attention.is_none() {
|
||||||
|
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
|
||||||
|
attention = Some("flashdecoding".to_string());
|
||||||
|
}
|
||||||
|
if prefix_caching.is_none() {
|
||||||
|
prefix_caching = Some("0".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
|
||||||
|
let attention = attention.unwrap_or("flashinfer".to_string());
|
||||||
|
(prefix_caching, attention)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct RawConfig {
|
struct RawConfig {
|
||||||
max_position_embeddings: Option<usize>,
|
max_position_embeddings: Option<usize>,
|
||||||
|
@ -31,6 +122,12 @@ struct RawConfig {
|
||||||
model_type: Option<String>,
|
model_type: Option<String>,
|
||||||
max_seq_len: Option<usize>,
|
max_seq_len: Option<usize>,
|
||||||
quantization_config: Option<QuantizationConfig>,
|
quantization_config: Option<QuantizationConfig>,
|
||||||
|
n_embd: Option<usize>,
|
||||||
|
hidden_size: Option<usize>,
|
||||||
|
num_attention_heads: Option<usize>,
|
||||||
|
head_dim: Option<usize>,
|
||||||
|
vision_config: Option<VisionConfig>,
|
||||||
|
is_encoder_decoder: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
@ -38,10 +135,17 @@ struct QuantizationConfig {
|
||||||
quant_method: Option<Quantization>,
|
quant_method: Option<Quantization>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct VisionConfig {}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Config {
|
struct Config {
|
||||||
max_position_embeddings: Option<usize>,
|
max_position_embeddings: Option<usize>,
|
||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
|
head_dim: Option<usize>,
|
||||||
|
model_type: Option<String>,
|
||||||
|
vision_config: Option<VisionConfig>,
|
||||||
|
is_encoder_decoder: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<RawConfig> for Config {
|
impl From<RawConfig> for Config {
|
||||||
|
@ -51,9 +155,32 @@ impl From<RawConfig> for Config {
|
||||||
.or(other.max_seq_len)
|
.or(other.max_seq_len)
|
||||||
.or(other.n_positions);
|
.or(other.n_positions);
|
||||||
let quantize = other.quantization_config.and_then(|q| q.quant_method);
|
let quantize = other.quantization_config.and_then(|q| q.quant_method);
|
||||||
|
let head_dim = other.head_dim.or_else(|| {
|
||||||
|
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
|
||||||
|
(Some(hidden_size), _, Some(num_attention_heads))
|
||||||
|
if hidden_size % num_attention_heads == 0 =>
|
||||||
|
{
|
||||||
|
Some(hidden_size / num_attention_heads)
|
||||||
|
}
|
||||||
|
// Legacy
|
||||||
|
(_, Some(hidden_size), Some(num_attention_heads))
|
||||||
|
if hidden_size % num_attention_heads == 0 =>
|
||||||
|
{
|
||||||
|
Some(hidden_size / num_attention_heads)
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let model_type = other.model_type;
|
||||||
|
let vision_config = other.vision_config;
|
||||||
|
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
|
||||||
Config {
|
Config {
|
||||||
max_position_embeddings,
|
max_position_embeddings,
|
||||||
quantize,
|
quantize,
|
||||||
|
head_dim,
|
||||||
|
model_type,
|
||||||
|
vision_config,
|
||||||
|
is_encoder_decoder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -731,6 +858,7 @@ fn shard_manager(
|
||||||
.args(shard_args)
|
.args(shard_args)
|
||||||
.env_clear()
|
.env_clear()
|
||||||
.envs(envs)
|
.envs(envs)
|
||||||
|
.stdin(Stdio::piped())
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.process_group(0)
|
.process_group(0)
|
||||||
|
@ -752,12 +880,13 @@ fn shard_manager(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Redirect STDOUT to the console
|
// Redirect STDOUT to the console
|
||||||
|
let mut pstdin = p.stdin.take().unwrap();
|
||||||
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
|
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
|
||||||
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
|
||||||
|
|
||||||
//stdout tracing thread
|
//stdout tracing thread
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
log_lines(shard_stdout_reader.lines());
|
log_lines(shard_stdout_reader);
|
||||||
});
|
});
|
||||||
// We read stderr in another thread as it seems that lines() can block in some cases
|
// We read stderr in another thread as it seems that lines() can block in some cases
|
||||||
let (err_sender, err_receiver) = mpsc::channel();
|
let (err_sender, err_receiver) = mpsc::channel();
|
||||||
|
@ -766,6 +895,18 @@ fn shard_manager(
|
||||||
err_sender.send(line).unwrap_or(());
|
err_sender.send(line).unwrap_or(());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
// We read stdin in another thread as it seems that lines() can block in some cases
|
||||||
|
thread::spawn(move || {
|
||||||
|
let mut stdin = io::stdin(); // We get `Stdin` here.
|
||||||
|
loop {
|
||||||
|
let mut buffer = vec![0; 4096];
|
||||||
|
if let Ok(n) = stdin.read(&mut buffer) {
|
||||||
|
if n > 0 {
|
||||||
|
let _ = pstdin.write_all(&buffer[..n]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let mut ready = false;
|
let mut ready = false;
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
@ -872,19 +1013,36 @@ impl PythonLogMessage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<&String> for PythonLogMessage {
|
impl TryFrom<&[u8]> for PythonLogMessage {
|
||||||
type Error = serde_json::Error;
|
type Error = serde_json::Error;
|
||||||
|
|
||||||
fn try_from(value: &String) -> Result<Self, Self::Error> {
|
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
|
||||||
serde_json::from_str::<Self>(value)
|
serde_json::from_slice::<Self>(value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
|
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||||
for line in lines.map_while(Result::ok) {
|
let mut buffer = vec![0u8; 8 * 4096];
|
||||||
match PythonLogMessage::try_from(&line) {
|
let mut stdout = std::io::stdout();
|
||||||
Ok(log) => log.trace(),
|
loop {
|
||||||
Err(_) => tracing::debug!("{line}"),
|
let n = bufread.read(&mut buffer);
|
||||||
|
if let Ok(n) = n {
|
||||||
|
if n > 0 {
|
||||||
|
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
|
||||||
|
while let Some(line) = lines.next() {
|
||||||
|
match PythonLogMessage::try_from(line) {
|
||||||
|
Ok(log) => log.trace(),
|
||||||
|
// For interactive debugging ?
|
||||||
|
Err(_) => {
|
||||||
|
stdout.write_all(line).unwrap();
|
||||||
|
if lines.peek().is_some() {
|
||||||
|
stdout.write_all(b"\n").unwrap();
|
||||||
|
}
|
||||||
|
stdout.flush().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1044,7 +1202,7 @@ fn download_convert_model(
|
||||||
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
|
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
|
||||||
|
|
||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
log_lines(download_stdout.lines());
|
log_lines(download_stdout);
|
||||||
});
|
});
|
||||||
|
|
||||||
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
|
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
|
||||||
|
@ -1439,68 +1597,35 @@ fn main() -> Result<(), LauncherError> {
|
||||||
|
|
||||||
tracing::info!("{:#?}", args);
|
tracing::info!("{:#?}", args);
|
||||||
|
|
||||||
let get_max_positions_quantize =
|
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
|
||||||
|| -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
|
let quantize = config.as_ref().and_then(|c| c.quantize);
|
||||||
let model_id = args.model_id.clone();
|
// Quantization usually means you're even more RAM constrained.
|
||||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
let max_default = 4096;
|
||||||
let filename = if !path.exists() {
|
|
||||||
// Assume it's a hub id
|
|
||||||
|
|
||||||
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
|
let max_position_embeddings = if let Some(config) = &config {
|
||||||
// env variable has precedence over on file token.
|
if let Some(max_position_embeddings) = config.max_position_embeddings {
|
||||||
ApiBuilder::new().with_token(Some(token)).build()?
|
if max_position_embeddings > max_default {
|
||||||
} else {
|
let max = max_position_embeddings;
|
||||||
Api::new()?
|
if args.max_input_tokens.is_none()
|
||||||
};
|
&& args.max_total_tokens.is_none()
|
||||||
let repo = if let Some(ref revision) = args.revision {
|
&& args.max_batch_prefill_tokens.is_none()
|
||||||
api.repo(Repo::with_revision(
|
{
|
||||||
model_id,
|
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||||
RepoType::Model,
|
|
||||||
revision.to_string(),
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
api.model(model_id)
|
|
||||||
};
|
|
||||||
repo.get("config.json")?
|
|
||||||
} else {
|
|
||||||
path.push("config.json");
|
|
||||||
path
|
|
||||||
};
|
|
||||||
|
|
||||||
let content = std::fs::read_to_string(filename)?;
|
|
||||||
let config: RawConfig = serde_json::from_str(&content)?;
|
|
||||||
|
|
||||||
if config.model_type == Some("gemma2".to_string()) {
|
|
||||||
tracing::info!("Forcing flash decoding because of softcap usage");
|
|
||||||
std::env::set_var("ATTENTION", "flashdecoding");
|
|
||||||
}
|
|
||||||
let config: Config = config.into();
|
|
||||||
let quantize = config.quantize;
|
|
||||||
|
|
||||||
// Quantization usually means you're even more RAM constrained.
|
|
||||||
let max_default = 4096;
|
|
||||||
|
|
||||||
if let Some(max_position_embeddings) = config.max_position_embeddings {
|
|
||||||
if max_position_embeddings > max_default {
|
|
||||||
let max = max_position_embeddings;
|
|
||||||
if args.max_input_tokens.is_none()
|
|
||||||
&& args.max_total_tokens.is_none()
|
|
||||||
&& args.max_batch_prefill_tokens.is_none()
|
|
||||||
{
|
|
||||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
|
||||||
}
|
|
||||||
Ok((max_default, quantize))
|
|
||||||
} else {
|
|
||||||
Ok((max_position_embeddings, quantize))
|
|
||||||
}
|
}
|
||||||
|
max_default
|
||||||
} else {
|
} else {
|
||||||
Err(Box::new(LauncherError::ArgumentValidation(
|
max_position_embeddings
|
||||||
"no max defined".to_string(),
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
};
|
} else {
|
||||||
let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
|
max_default
|
||||||
get_max_positions_quantize().unwrap_or((4096, None));
|
}
|
||||||
|
} else {
|
||||||
|
max_default
|
||||||
|
};
|
||||||
|
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
|
||||||
|
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
|
||||||
|
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
|
||||||
|
std::env::set_var("ATTENTION", attention);
|
||||||
|
|
||||||
let max_input_tokens = {
|
let max_input_tokens = {
|
||||||
match (args.max_input_tokens, args.max_input_length) {
|
match (args.max_input_tokens, args.max_input_length) {
|
||||||
|
|
|
@ -33,13 +33,13 @@ export function get_options() {
     //     rate: 20,
     //     timeUnit: '1s',
     // },
-    load_test: {
-        executor: 'constant-arrival-rate',
-        duration: '60s',
-        preAllocatedVUs: 100,
-        rate: 1,
-        timeUnit: '1s',
-    },
+    // load_test: {
+    //     executor: 'constant-arrival-rate',
+    //     duration: '60s',
+    //     preAllocatedVUs: 100,
+    //     rate: 1,
+    //     timeUnit: '1s',
+    // },
     // breakpoint: {
     //     executor: 'ramping-arrival-rate', //Assure load increase if the system slows
     //     preAllocatedVUs: 300,
@ -47,12 +47,12 @@ export function get_options() {
     //     { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
     //     ],
     // },
-    // throughput: {
-    //     executor: 'shared-iterations',
-    //     vus: 100,
-    //     iterations: 200,
-    //     maxDuration: '40s',
-    // },
+    throughput: {
+        executor: 'shared-iterations',
+        vus: 100,
+        iterations: 200,
+        maxDuration: '40s',
+    },
     },
   };
 }
@ -20,34 +20,52 @@ defaultCrateOverrides
|
||||||
rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };
|
rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = "-C target-feature=-crt-static"; };
|
||||||
|
|
||||||
grpc-metadata = attrs: {
|
grpc-metadata = attrs: {
|
||||||
src =
|
src = filter {
|
||||||
filter {
|
root = ../backends/grpc-metadata;
|
||||||
root = ../backends/grpc-metadata;
|
include = with filter; [
|
||||||
include = with filter; [
|
isDirectory
|
||||||
isDirectory
|
(matchExt "rs")
|
||||||
(matchExt "rs")
|
];
|
||||||
];
|
};
|
||||||
};
|
|
||||||
};
|
};
|
||||||
text-generation-launcer = attrs: {
|
text-generation-benchmark = attrs: {
|
||||||
src =
|
src = filter {
|
||||||
filter {
|
root = ../benchmark;
|
||||||
root = ../launcher;
|
include = with filter; [
|
||||||
include = with filter; [
|
isDirectory
|
||||||
isDirectory
|
(matchExt "rs")
|
||||||
(matchExt "rs")
|
];
|
||||||
];
|
};
|
||||||
};
|
};
|
||||||
|
text-generation-client = attrs: {
|
||||||
|
src = filter {
|
||||||
|
root = ../.;
|
||||||
|
include = with filter; [
|
||||||
|
isDirectory
|
||||||
|
(and (inDirectory "backends/client") (matchExt "rs"))
|
||||||
|
(and (inDirectory "proto") (matchExt "proto"))
|
||||||
|
];
|
||||||
|
};
|
||||||
|
postPatch = "cd backends/client";
|
||||||
|
buildInputs = [ protobuf ];
|
||||||
|
};
|
||||||
|
text-generation-launcher = attrs: {
|
||||||
|
src = filter {
|
||||||
|
root = ../launcher;
|
||||||
|
include = with filter; [
|
||||||
|
isDirectory
|
||||||
|
(matchExt "rs")
|
||||||
|
];
|
||||||
|
};
|
||||||
};
|
};
|
||||||
text-generation-router = attrs: {
|
text-generation-router = attrs: {
|
||||||
src =
|
src = filter {
|
||||||
filter {
|
root = ../router;
|
||||||
root = ../router;
|
include = with filter; [
|
||||||
include = with filter; [
|
isDirectory
|
||||||
isDirectory
|
(matchExt "rs")
|
||||||
(matchExt "rs")
|
];
|
||||||
];
|
};
|
||||||
};
|
|
||||||
};
|
};
|
||||||
text-generation-router-v3 = attrs: {
|
text-generation-router-v3 = attrs: {
|
||||||
# We need to do the src/source root dance so that the build
|
# We need to do the src/source root dance so that the build
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
{
|
||||||
|
nix-filter,
|
||||||
|
buildPythonPackage,
|
||||||
|
poetry-core,
|
||||||
|
mypy-protobuf,
|
||||||
|
awq-inference-engine,
|
||||||
|
causal-conv1d,
|
||||||
|
eetq,
|
||||||
|
einops,
|
||||||
|
exllamav2,
|
||||||
|
fbgemm-gpu,
|
||||||
|
flashinfer,
|
||||||
|
flash-attn,
|
||||||
|
flash-attn-layer-norm,
|
||||||
|
flash-attn-rotary,
|
||||||
|
grpc-interceptor,
|
||||||
|
grpcio-reflection,
|
||||||
|
grpcio-status,
|
||||||
|
grpcio-tools,
|
||||||
|
hf-transfer,
|
||||||
|
loguru,
|
||||||
|
mamba-ssm,
|
||||||
|
marlin-kernels,
|
||||||
|
opentelemetry-api,
|
||||||
|
opentelemetry-exporter-otlp,
|
||||||
|
opentelemetry-instrumentation-grpc,
|
||||||
|
opentelemetry-semantic-conventions,
|
||||||
|
peft,
|
||||||
|
punica-kernels,
|
||||||
|
safetensors,
|
||||||
|
tokenizers,
|
||||||
|
torch,
|
||||||
|
sentencepiece,
|
||||||
|
transformers,
|
||||||
|
typer,
|
||||||
|
vllm,
|
||||||
|
}:
|
||||||
|
|
||||||
|
let
|
||||||
|
filter = nix-filter.lib;
|
||||||
|
in
|
||||||
|
buildPythonPackage {
|
||||||
|
name = "text-generation-server";
|
||||||
|
|
||||||
|
src = filter {
|
||||||
|
root = ../.;
|
||||||
|
include = with filter; [
|
||||||
|
isDirectory
|
||||||
|
(and (inDirectory "server") (or_ (matchExt "py") (matchExt "pyi")))
|
||||||
|
"server/pyproject.toml"
|
||||||
|
(and (inDirectory "proto/v3") (matchExt "proto"))
|
||||||
|
];
|
||||||
|
};
|
||||||
|
|
||||||
|
pyproject = true;
|
||||||
|
|
||||||
|
build-system = [ poetry-core ];
|
||||||
|
|
||||||
|
nativeBuildInputs = [ mypy-protobuf ];
|
||||||
|
|
||||||
|
pythonRelaxDeps = [
|
||||||
|
"einops"
|
||||||
|
"huggingface-hub"
|
||||||
|
"loguru"
|
||||||
|
"opentelemetry-instrumentation-grpc"
|
||||||
|
"sentencepiece"
|
||||||
|
"typer"
|
||||||
|
];
|
||||||
|
|
||||||
|
pythonRemoveDeps = [ "scipy" ];
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
awq-inference-engine
|
||||||
|
eetq
|
||||||
|
causal-conv1d
|
||||||
|
einops
|
||||||
|
exllamav2
|
||||||
|
fbgemm-gpu
|
||||||
|
flashinfer
|
||||||
|
flash-attn
|
||||||
|
flash-attn-layer-norm
|
||||||
|
flash-attn-rotary
|
||||||
|
grpc-interceptor
|
||||||
|
grpcio-reflection
|
||||||
|
grpcio-status
|
||||||
|
grpcio-tools
|
||||||
|
hf-transfer
|
||||||
|
loguru
|
||||||
|
mamba-ssm
|
||||||
|
marlin-kernels
|
||||||
|
opentelemetry-api
|
||||||
|
opentelemetry-exporter-otlp
|
||||||
|
opentelemetry-instrumentation-grpc
|
||||||
|
opentelemetry-semantic-conventions
|
||||||
|
peft
|
||||||
|
punica-kernels
|
||||||
|
safetensors
|
||||||
|
sentencepiece
|
||||||
|
tokenizers
|
||||||
|
transformers
|
||||||
|
typer
|
||||||
|
vllm
|
||||||
|
];
|
||||||
|
|
||||||
|
prePatch = ''
|
||||||
|
python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \
|
||||||
|
--grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto
|
||||||
|
find server/text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||||
|
touch server/text_generation_server/pb/__init__.py
|
||||||
|
cd server
|
||||||
|
'';
|
||||||
|
}
|
|
@ -137,6 +137,8 @@ message Request {
     optional string adapter_id = 11;
     /// Prefix length that can be retrieved from the KV cache.
     uint32 prefix_len = 12;
+    /// Context truncation
+    bool add_special_tokens = 13;
 }

 message Batch {
@ -46,8 +46,8 @@ ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = [
     "opentelemetry-otlp",
 ] }
-minijinja = { version = "2.0.2" }
-minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
+minijinja = { workspace = true }
+minijinja-contrib = { workspace = true }
 futures-util = "0.3.30"
 regex = "1.10.3"
 once_cell = "1.19.0"
@ -1,11 +1,8 @@
|
||||||
use std::collections::HashSet;
|
|
||||||
|
|
||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{
|
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
|
||||||
};
|
|
||||||
use minijinja::{Environment, ErrorKind, Template};
|
use minijinja::{Environment, ErrorKind, Template};
|
||||||
use minijinja_contrib::pycompat;
|
use minijinja_contrib::pycompat;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
/// Raise a exception (custom function) used in the chat templates
|
/// Raise a exception (custom function) used in the chat templates
|
||||||
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||||
|
@ -32,6 +29,7 @@ impl ChatTemplate {
|
||||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||||
let template_str = template.into_boxed_str();
|
let template_str = template.into_boxed_str();
|
||||||
env.add_function("raise_exception", raise_exception);
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
tracing::debug!("Loading template: {:#?}", template_str);
|
||||||
|
|
||||||
// leaking env and template_str as read-only, static resources for performance.
|
// leaking env and template_str as read-only, static resources for performance.
|
||||||
let template = Box::leak(env)
|
let template = Box::leak(env)
|
||||||
|
@ -42,6 +40,7 @@ impl ChatTemplate {
|
||||||
let variables = template.undeclared_variables(true);
|
let variables = template.undeclared_variables(true);
|
||||||
// check if the `tools` variable is used in the template
|
// check if the `tools` variable is used in the template
|
||||||
let use_default_tool_template = !variables.contains("tools");
|
let use_default_tool_template = !variables.contains("tools");
|
||||||
|
tracing::debug!("Use default tool template: {}", use_default_tool_template);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
template,
|
template,
|
||||||
|
@ -56,25 +55,36 @@ impl ChatTemplate {
|
||||||
&self,
|
&self,
|
||||||
guideline: Option<&str>,
|
guideline: Option<&str>,
|
||||||
mut messages: Vec<Message>,
|
mut messages: Vec<Message>,
|
||||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||||
) -> Result<String, InferError> {
|
) -> Result<String, InferError> {
|
||||||
if self.use_default_tool_template {
|
|
||||||
if let Some(last_message) = messages.last_mut() {
|
|
||||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
|
||||||
last_message.content.push(MessageChunk::Text {
|
|
||||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
|
||||||
|
|
||||||
// check if guideline is expected but not provided
|
// check if guideline is expected but not provided
|
||||||
if self.variables.contains("guideline") && guideline.is_none() {
|
if self.variables.contains("guideline") && guideline.is_none() {
|
||||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let tools = match tools_and_prompt {
|
||||||
|
Some((tools, tool_prompt)) => {
|
||||||
|
// check if the `tools` variable is used in the template
|
||||||
|
// if not, we need to append the tools to the last message
|
||||||
|
let text = if self.use_default_tool_template {
|
||||||
|
match serde_json::to_string(&tools) {
|
||||||
|
Ok(tools_str) => format!("\n---\n{}\n{}", tools_str, tool_prompt),
|
||||||
|
Err(e) => return Err(InferError::ToolError(e.to_string())),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// if the `tools` variable is used in the template, we just append the tool_prompt
|
||||||
|
format!("\n---\n{}", tool_prompt)
|
||||||
|
};
|
||||||
|
if let Some(last_message) = messages.last_mut() {
|
||||||
|
last_message.content.push(MessageChunk::Text { text });
|
||||||
|
}
|
||||||
|
Some(tools)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||||
|
|
||||||
self.template
|
self.template
|
||||||
.render(ChatTemplateInputs {
|
.render(ChatTemplateInputs {
|
||||||
guideline,
|
guideline,
|
||||||
|
@ -82,8 +92,7 @@ impl ChatTemplate {
|
||||||
bos_token: self.bos_token.as_deref(),
|
bos_token: self.bos_token.as_deref(),
|
||||||
eos_token: self.eos_token.as_deref(),
|
eos_token: self.eos_token.as_deref(),
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
tools: None,
|
tools,
|
||||||
tools_prompt: None,
|
|
||||||
})
|
})
|
||||||
.map_err(InferError::TemplateError)
|
.map_err(InferError::TemplateError)
|
||||||
}
|
}
|
||||||
|
@ -95,7 +104,7 @@ mod tests {
|
||||||
use crate::infer::chat_template::raise_exception;
|
use crate::infer::chat_template::raise_exception;
|
||||||
use crate::infer::ChatTemplate;
|
use crate::infer::ChatTemplate;
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
|
ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken, Tool,
|
||||||
};
|
};
|
||||||
use minijinja::Environment;
|
use minijinja::Environment;
|
||||||
|
|
||||||
|
@ -854,11 +863,46 @@ mod tests {
|
||||||
content: MessageContent::SingleText("Just testing".to_string()),
|
content: MessageContent::SingleText("Just testing".to_string()),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let tools = serde_json::json!("[]");
|
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||||
|
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||||
let tool_prompt = "This default prompt will be used".to_string();
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
|
let tools_and_prompt = Some((tools, tool_prompt));
|
||||||
let result = ct.apply(None, msgs, Some(grammer_with_prompt));
|
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
|
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||||
|
assert_eq!(result.unwrap(), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_with_custom_tool_template() {
|
||||||
|
// chat template from meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||||
|
let ct = ChatTemplate::new(
|
||||||
|
"{{- bos_token }}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n".to_string(),
|
||||||
|
Some(TokenizerConfigToken::String("<s>".to_string())),
|
||||||
|
Some(TokenizerConfigToken::String("</s>".to_string())),
|
||||||
|
);
|
||||||
|
let msgs: Vec<Message> = vec![
|
||||||
|
Message {
|
||||||
|
name: None,
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: MessageContent::SingleText(
|
||||||
|
"Youre a helpful assistant! Answer the users question best you can."
|
||||||
|
.to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
name: None,
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: MessageContent::SingleText(
|
||||||
|
"What is the weather like in Brooklyn, New York?".to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
|
||||||
|
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||||
|
let tool_prompt = "This default prompt will be used".to_string();
|
||||||
|
let tools_and_prompt = Some((tools, tool_prompt));
|
||||||
|
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||||
|
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||||
assert_eq!(result.unwrap(), expected);
|
assert_eq!(result.unwrap(), expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ mod chat_template;
 pub mod tool_grammar;

 use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
-use crate::GrammarType;
+use crate::Tool;
 use crate::{
     ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
     Message, PrefillToken, Token,
@ -120,10 +120,11 @@ impl Infer {
     ) -> Result<Option<tokenizers::Encoding>, InferError> {
         // Tokenize request
         let inputs = request.inputs;
+        let add_special_tokens = request.add_special_tokens;
         let truncate = request.parameters.truncate;
         let encoding = self
             .validation
-            .tokenize(inputs, truncate)
+            .tokenize(inputs, add_special_tokens, truncate)
             .await
             .map_err(|err| {
                 tracing::error!("Tokenization {err}");
@ -140,12 +141,12 @@ impl Infer {
         &self,
         guideline: Option<String>,
         messages: Vec<Message>,
-        grammar_with_prompt: Option<(GrammarType, String)>,
+        tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, grammar_with_prompt)
+            .apply(guideline.as_deref(), messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
@ -1,5 +1,8 @@
|
||||||
use crate::infer::InferError;
|
use crate::infer::InferError;
|
||||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
|
use crate::{
|
||||||
|
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
|
||||||
|
ToolType,
|
||||||
|
};
|
||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
@ -16,17 +19,38 @@ impl ToolGrammar {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn apply(
|
pub fn apply(
|
||||||
tools: Option<Vec<Tool>>,
|
tools: Vec<Tool>,
|
||||||
tool_choice: ToolChoice,
|
tool_choice: ToolChoice,
|
||||||
) -> Result<Option<Tools>, InferError> {
|
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
|
||||||
// if no tools are provided, we return None
|
// if no tools are provided, we return None
|
||||||
let tools = match tools {
|
if tools.is_empty() {
|
||||||
Some(tools) if !tools.is_empty() => tools,
|
return Ok((tools, None));
|
||||||
_ => return Ok(None),
|
}
|
||||||
};
|
|
||||||
|
|
||||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||||
|
|
||||||
|
let mut tools = tools.clone();
|
||||||
|
|
||||||
|
// add the notify_error function to the tools
|
||||||
|
let notify_error = Tool {
|
||||||
|
r#type: "function".to_string(),
|
||||||
|
function: FunctionDefinition {
|
||||||
|
name: "notify_error".to_string(),
|
||||||
|
description: Some("Notify an error or issue".to_string()),
|
||||||
|
arguments: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"error": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The error or issue to notify"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["error"]
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
tools.push(notify_error);
|
||||||
|
|
||||||
// if tools are provided and no tool_choice we default to the OneOf
|
// if tools are provided and no tool_choice we default to the OneOf
|
||||||
let tools_to_use = match tool_choice {
|
let tools_to_use = match tool_choice {
|
||||||
ToolType::FunctionName(name) => {
|
ToolType::FunctionName(name) => {
|
||||||
|
@ -35,87 +59,57 @@ impl ToolGrammar {
|
||||||
ToolType::Function { function } => {
|
ToolType::Function { function } => {
|
||||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||||
}
|
}
|
||||||
ToolType::OneOf => tools,
|
ToolType::OneOf => tools.clone(),
|
||||||
ToolType::NoTool => return Ok(None),
|
ToolType::NoTool => return Ok((tools, None)),
|
||||||
};
|
};
|
||||||
|
|
||||||
// adds the error notification function for LLM feedback if required
|
|
||||||
let mut text_response_properties = Map::new();
|
|
||||||
text_response_properties.insert(
|
|
||||||
"error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"description": "The error or issue to notify"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
text_response_properties.insert(
|
|
||||||
"_name".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "string",
|
|
||||||
"const": "notify_error"
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||||
.iter()
|
.iter()
|
||||||
.map(|tool| {
|
.map(|tool| {
|
||||||
let func = tool.function.clone();
|
let func = tool.function.clone();
|
||||||
|
|
||||||
// Clone the existing parameters, which are expected to be a JSON object
|
let mut params = Map::new();
|
||||||
let mut params = if let Value::Object(params) = &func.arguments {
|
|
||||||
params.clone()
|
|
||||||
} else {
|
|
||||||
Map::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Insert the function's description at the top level, outside of properties
|
|
||||||
params.insert(
|
params.insert(
|
||||||
"description".to_string(),
|
"description".to_string(),
|
||||||
Value::String(func.description.clone().unwrap_or_default()),
|
Value::String(func.description.unwrap_or_default()),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Ensure 'properties' exists and is an object
|
let mut properties = Map::new();
|
||||||
let properties = params
|
let mut required = vec![Value::String("_name".to_string())];
|
||||||
.entry("properties".to_string())
|
|
||||||
.or_insert_with(|| json!({}))
|
|
||||||
.as_object_mut()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Insert the constant for the function name inside 'properties'
|
|
||||||
properties.insert(
|
properties.insert(
|
||||||
"_name".to_string(),
|
"_name".to_string(),
|
||||||
json!({
|
json!({
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"const": func.name.clone(),
|
"const": func.name.clone(),
|
||||||
// "description": "The name of the function"
|
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
if let Value::Object(args) = func.arguments {
|
||||||
let required = params
|
if let Some(Value::Object(props)) = args.get("properties") {
|
||||||
.entry("required".to_string())
|
properties.extend(props.clone());
|
||||||
.or_insert_with(|| json!([]))
|
}
|
||||||
.as_array_mut()
|
if let Some(Value::Array(reqs)) = args.get("required") {
|
||||||
.unwrap();
|
required.extend(reqs.clone());
|
||||||
|
}
|
||||||
// Add 'name' to the 'required' array if it is not already present
|
params.insert(
|
||||||
if !required.iter().any(|r| r == "_name") {
|
"additionalProperties".to_string(),
|
||||||
required.push(json!("_name"));
|
Value::Bool(
|
||||||
|
args.get("additionalProperties").and_then(|v| v.as_str())
|
||||||
|
== Some("true"),
|
||||||
|
),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params.insert("properties".to_string(), Value::Object(properties));
|
||||||
|
params.insert("required".to_string(), Value::Array(required));
|
||||||
|
|
||||||
(func.name, Value::Object(params))
|
(func.name, Value::Object(params))
|
||||||
})
|
})
|
||||||
.chain([(
|
|
||||||
"notify_error".to_string(),
|
|
||||||
serde_json::json!({
|
|
||||||
"properties": text_response_properties,
|
|
||||||
"required": ["error", "_name"],
|
|
||||||
"type": "object"
|
|
||||||
}),
|
|
||||||
)])
|
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let tools = Tools {
|
let tool_schema = JsonSchemaTool {
|
||||||
functions_map: FunctionsMap { functions },
|
functions_map: FunctionsMap { functions },
|
||||||
properties: Properties {
|
properties: Properties {
|
||||||
function: tools_to_use
|
function: tools_to_use
|
||||||
|
@ -123,13 +117,10 @@ impl ToolGrammar {
|
||||||
.map(|tool| FunctionRef {
|
.map(|tool| FunctionRef {
|
||||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||||
})
|
})
|
||||||
.chain(std::iter::once(FunctionRef {
|
|
||||||
ref_path: "#/$functions/notify_error".to_string(),
|
|
||||||
}))
|
|
||||||
.collect(),
|
.collect(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(tools))
|
Ok((tools, Some(tool_schema)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@@ -22,6 +22,16 @@ pub enum Attention {
     FlashInfer,
 }
 
+impl Attention {
+    pub fn block_size(&self) -> u32 {
+        match self {
+            Attention::FlashDecoding => 256,
+            Attention::FlashInfer => 1,
+            Attention::Paged => 16,
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct ParseError;
 
@@ -45,13 +55,20 @@ impl std::str::FromStr for Attention {
     }
 }
 
 #[derive(Clone, Deserialize, ToSchema)]
-pub(crate) struct VertexInstance {
+pub(crate) struct GenerateVertexInstance {
     #[schema(example = "What is Deep Learning?")]
     pub inputs: String,
     #[schema(nullable = true, default = "null", example = "null")]
     pub parameters: Option<GenerateParameters>,
 }
 
+#[derive(Clone, Deserialize, ToSchema)]
+#[serde(untagged)]
+enum VertexInstance {
+    Generate(GenerateVertexInstance),
+    Chat(ChatRequest),
+}
+
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
     #[serde(rename = "instances")]
@@ -840,10 +857,10 @@ pub(crate) struct ChatRequest {
     pub tools: Option<Vec<Tool>>,
 
     /// A prompt to be appended before the tools
-    #[serde(default = "default_tool_prompt")]
+    #[serde(default)]
     #[schema(
         nullable = true,
-        example = "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\""
+        example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
     )]
     pub tool_prompt: Option<String>,
 
@@ -865,10 +882,8 @@ pub(crate) struct ChatRequest {
     pub guideline: Option<String>,
 }
 
-fn default_tool_prompt() -> Option<String> {
-    Some(
-        "\nYou will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n".to_string(),
-    )
+pub fn default_tool_prompt() -> String {
+    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
 }
 
 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
@@ -910,7 +925,7 @@ impl From<ToolTypeDeserializer> for ToolChoice {
     }
 }
 
 #[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]
-pub struct Tools {
+pub struct JsonSchemaTool {
     #[serde(flatten)]
     functions_map: FunctionsMap,
     properties: Properties,
@@ -968,8 +983,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
     bos_token: Option<&'a str>,
     eos_token: Option<&'a str>,
     add_generation_prompt: bool,
-    tools: Option<&'a str>,
-    tools_prompt: Option<&'a str>,
+    tools: Option<Vec<Tool>>,
     guideline: Option<&'a str>,
 }
 
@@ -1075,6 +1089,16 @@ pub(crate) struct GenerateRequest {
     pub inputs: String,
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
+
+    /// This is used internally because some requests
+    /// already contain the templated input therefore
+    /// we shouldn't add the special tokens.
+    #[serde(default = "default_true", skip)]
+    pub add_special_tokens: bool,
+}
+
+fn default_true() -> bool {
+    true
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
@@ -1092,6 +1116,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
     fn from(req: CompatGenerateRequest) -> Self {
         Self {
             inputs: req.inputs,
+            add_special_tokens: true,
             parameters: req.parameters,
         }
     }
@@ -1243,6 +1268,34 @@ pub(crate) struct ErrorResponse {
     pub error_type: String,
 }
 
+#[derive(Serialize, Deserialize, ToSchema)]
+pub(crate) struct ModelInfo {
+    #[schema(example = "gpt2")]
+    pub id: String,
+    #[schema(example = "model")]
+    pub object: String,
+    #[schema(example = 1686935002)]
+    pub created: u64,
+    #[schema(example = "openai")]
+    pub owned_by: String,
+}
+
+#[derive(Serialize, Deserialize, ToSchema)]
+pub(crate) struct ModelsInfo {
+    #[schema(example = "list")]
+    pub object: String,
+    pub data: Vec<ModelInfo>,
+}
+
+impl Default for ModelsInfo {
+    fn default() -> Self {
+        ModelsInfo {
+            object: "list".to_string(),
+            data: Vec::new(),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
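As a quick illustration of what the new ModelInfo/ModelsInfo structs serialize to, here is a Python sketch of the expected /v1/models payload; the field names come from the structs above, the values are examples only.

import json

# Assumed shape of the OpenAI-compatible /v1/models response produced by the
# new ModelsInfo/ModelInfo structs; values are illustrative.
models_response = {
    "object": "list",
    "data": [
        {
            "id": "gpt2",        # the served model_id
            "object": "model",
            "created": 0,        # the handler currently hardcodes 0
            "owned_by": "gpt2",  # also set to the model_id
        }
    ],
}
print(json.dumps(models_response, indent=2))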
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::ChatTokenizeResponse;
+use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -23,7 +23,8 @@ use crate::{
     CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
     VertexResponse,
 };
-use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
+use crate::{ModelInfo, ModelsInfo};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@@ -116,6 +117,29 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/v1/models",
+    responses(
+        (status = 200, description = "Served model info", body = ModelInfo),
+        (status = 404, description = "Model not found", body = ErrorResponse),
+    )
+)]
+#[instrument(skip(info))]
+/// Get model info
+async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
+    Json(ModelsInfo {
+        data: vec![ModelInfo {
+            id: info.0.model_id.clone(),
+            object: "model".to_string(),
+            created: 0, // TODO: determine how to get this
+            owned_by: info.0.model_id.clone(),
+        }],
+        ..Default::default()
+    })
+}
+
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
@@ -146,7 +170,7 @@ async fn get_chat_tokenize(
     } = req;
 
     let tool_prompt = tool_prompt.unwrap_or_default();
-    let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
+    let (inputs, _grammar, _using_tools) = prepare_chat_input(
         &infer,
         response_format,
         tools,
@@ -158,6 +182,7 @@ async fn get_chat_tokenize(
 
     let generate_request = GenerateRequest {
         inputs,
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -754,6 +779,7 @@ async fn completions(
         .iter()
         .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
+            add_special_tokens: true,
             parameters: GenerateParameters {
                 best_of: None,
                 temperature,
@@ -1158,14 +1184,16 @@ async fn chat_completions(
     let repetition_penalty = presence_penalty.map(|x| x + 2.0);
     let max_new_tokens = max_tokens.or(Some(100));
     let logprobs = logprobs.unwrap_or(false);
-    let tool_prompt = tool_prompt.unwrap_or_default();
+    let tool_prompt = tool_prompt
+        .filter(|s| !s.is_empty())
+        .unwrap_or_else(default_tool_prompt);
     let stop = stop.unwrap_or_default();
     // enable greedy only when temperature is 0
     let (do_sample, temperature) = match temperature {
         Some(temperature) if temperature == 0.0 => (false, None),
         other => (true, other),
     };
-    let (inputs, grammar, tool_grammar) = prepare_chat_input(
+    let (inputs, grammar, using_tools) = prepare_chat_input(
         &infer,
         response_format,
         tools,
@@ -1178,6 +1206,7 @@ async fn chat_completions(
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -1221,7 +1250,7 @@ async fn chat_completions(
             });
 
             // replace the content with the tool calls if grammar is present
-            let (content, tool_calls) = if tool_grammar.is_some() {
+            let (content, tool_calls) = if using_tools {
                 (None, Some(vec![stream_token.token.text]))
             } else {
                 let content = if !stream_token.token.special {
@@ -1275,7 +1304,7 @@ async fn chat_completions(
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();
 
-        let (tool_calls, output) = if tool_grammar.is_some() {
+        let (tool_calls, output) = if using_tools {
             let gen_text_value: Value =
                 serde_json::from_str(&generation.generated_text).map_err(|e| {
                     InferError::ToolError(format!(
@@ -1377,13 +1406,14 @@ async fn vertex_compatibility(
         ));
     }
 
-    // Process all instances
-    let predictions = req
-        .instances
-        .iter()
-        .map(|instance| {
-            let generate_request = GenerateRequest {
-                inputs: instance.inputs.clone(),
+    // Prepare futures for all instances
+    let mut futures = Vec::with_capacity(req.instances.len());
+
+    for instance in req.instances.iter() {
+        let generate_request = match instance {
+            VertexInstance::Generate(instance) => GenerateRequest {
+                inputs: instance.inputs.clone(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     do_sample: true,
                     max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
@@ -1392,31 +1422,117 @@ async fn vertex_compatibility(
                     decoder_input_details: true,
                     ..Default::default()
                 },
-            };
-
-            async {
-                generate_internal(
-                    Extension(infer.clone()),
-                    compute_type.clone(),
-                    Json(generate_request),
-                    span.clone(),
-                )
-                .await
-                .map(|(_, Json(generation))| generation.generated_text)
-                .map_err(|_| {
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Incomplete generation".into(),
-                            error_type: "Incomplete generation".into(),
-                        }),
-                    )
-                })
-        })
-        .collect::<FuturesUnordered<_>>()
-        .try_collect::<Vec<_>>()
-        .await?;
+            },
+            VertexInstance::Chat(instance) => {
+                let ChatRequest {
+                    model,
+                    max_tokens,
+                    messages,
+                    seed,
+                    stop,
+                    stream,
+                    tools,
+                    tool_choice,
+                    tool_prompt,
+                    temperature,
+                    response_format,
+                    guideline,
+                    presence_penalty,
+                    frequency_penalty,
+                    top_p,
+                    top_logprobs,
+                    ..
+                } = instance.clone();
+
+                let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+                let max_new_tokens = max_tokens.or(Some(100));
+                let tool_prompt = tool_prompt
+                    .filter(|s| !s.is_empty())
+                    .unwrap_or_else(default_tool_prompt);
+                let stop = stop.unwrap_or_default();
+                // enable greedy only when temperature is 0
+                let (do_sample, temperature) = match temperature {
+                    Some(temperature) if temperature == 0.0 => (false, None),
+                    other => (true, other),
+                };
+                let (inputs, grammar, _using_tools) = match prepare_chat_input(
+                    &infer,
+                    response_format,
+                    tools,
+                    tool_choice,
+                    &tool_prompt,
+                    guideline,
+                    messages,
+                ) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        return Err((
+                            StatusCode::BAD_REQUEST,
+                            Json(ErrorResponse {
+                                error: format!("Failed to prepare chat input: {}", e),
+                                error_type: "Input preparation error".to_string(),
+                            }),
+                        ));
+                    }
+                };
+
+                GenerateRequest {
+                    inputs: inputs.to_string(),
+                    add_special_tokens: false,
+                    parameters: GenerateParameters {
+                        best_of: None,
+                        temperature,
+                        repetition_penalty,
+                        frequency_penalty,
+                        top_k: None,
+                        top_p,
+                        typical_p: None,
+                        do_sample,
+                        max_new_tokens,
+                        return_full_text: None,
+                        stop,
+                        truncate: None,
+                        watermark: false,
+                        details: true,
+                        decoder_input_details: !stream,
+                        seed,
+                        top_n_tokens: top_logprobs,
+                        grammar,
+                        adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                    },
+                }
             }
+        };
+
+        let infer_clone = infer.clone();
+        let compute_type_clone = compute_type.clone();
+        let span_clone = span.clone();
+
+        futures.push(async move {
+            generate_internal(
+                Extension(infer_clone),
+                compute_type_clone,
+                Json(generate_request),
+                span_clone,
+            )
+            .await
+            .map(|(_, Json(generation))| generation.generated_text)
+            .map_err(|_| {
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: "Incomplete generation".into(),
+                        error_type: "Incomplete generation".into(),
                     }),
+                )
+            })
+        });
+    }
+
+    // execute all futures in parallel, collect results, returning early if any error occurs
+    let results = futures::future::join_all(futures).await;
+    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
+    let predictions = predictions?;
 
     let response = VertexResponse { predictions };
     Ok((HeaderMap::new(), Json(response)).into_response())
@@ -1499,6 +1615,7 @@ chat_completions,
 completions,
 tokenize,
 metrics,
+openai_get_model_info,
 ),
 components(
 schemas(
@@ -1551,6 +1668,7 @@ ToolCall,
 Function,
 FunctionDefinition,
 ToolChoice,
+ModelInfo,
 )
 ),
 tags(
@@ -2244,7 +2362,8 @@ async fn start(
         .route("/info", get(get_model_info))
         .route("/health", get(health))
         .route("/ping", get(health))
-        .route("/metrics", get(metrics));
+        .route("/metrics", get(metrics))
+        .route("/v1/models", get(openai_get_model_info));
 
     // Conditional AWS Sagemaker route
     let aws_sagemaker_route = if messages_api_enabled {
@@ -2539,7 +2658,7 @@ fn create_post_processor(
     Ok(post_processor)
 }
 
-type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
+type PreparedInput = (String, Option<GrammarType>, bool);
 
 fn prepare_chat_input(
     infer: &Infer,
@@ -2556,19 +2675,139 @@ fn prepare_chat_input(
         ));
     }
 
+    // when response_format is set, tools are not included when applying the chat template to generate inputs
     if let Some(format) = response_format {
         let inputs = infer.apply_chat_template(guideline, messages, None)?;
-        return Ok((inputs, Some(format), None));
+        return Ok((inputs, Some(format), false));
     }
 
-    // if tools are set, apply the tool grammar and then the chat template
-    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
-    let grammar = tool_grammar
-        .as_ref()
-        .map(|t| GrammarType::Json(serde_json::json!(t)));
-    let tools_grammar_prompt = tool_grammar
-        .as_ref()
-        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
-    let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?;
-    Ok((inputs, grammar, tool_grammar))
+    // when no response_format is set and tools are included, apply the chat template with the tools
+    // to generate inputs
+    if let Some(tools) = tools {
+        let (updated_tools, tool_schema) = ToolGrammar::apply(tools, tool_choice)?;
+
+        let grammar = tool_schema
+            .as_ref()
+            .map(|t| GrammarType::Json(serde_json::json!(t)));
+
+        let inputs: String = infer.apply_chat_template(
+            guideline,
+            messages,
+            Some((updated_tools, tool_prompt.into())),
+        )?;
+        return Ok((inputs, grammar, tool_schema.is_some()));
+    }
+
+    // if no response_format or tools are set simply apply the chat template to generate inputs
+    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+    Ok((inputs, None, false))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ChatTemplateVersions;
+    use crate::HubTokenizerConfig;
+    use crate::TokenizerConfigToken;
+    use crate::Tool;
+
+    use serde_json::json;
+
+    #[test]
+    fn test_prepare_chat_input() {
+        // Mock Backend to avoid network requests
+        struct MockBackend;
+
+        impl Backend for MockBackend {
+            fn schedule(
+                &self,
+                _request: crate::validation::ValidGenerateRequest,
+            ) -> Result<
+                tokio_stream::wrappers::UnboundedReceiverStream<
+                    Result<InferStreamResponse, InferError>,
+                >,
+                InferError,
+            > {
+                unimplemented!("Never called in this test");
+            }
+            fn health<'a, 'async_trait>(
+                &'a self,
+                _current_health: bool,
+            ) -> core::pin::Pin<
+                Box<dyn core::future::Future<Output = bool> + core::marker::Send + 'async_trait>,
+            >
+            where
+                'a: 'async_trait,
+                Self: 'async_trait,
+            {
+                unimplemented!("Never called in this test");
+            }
+        }
+
+        let backend = MockBackend {};
+
+        let mut tokenizer_config = HubTokenizerConfig::default();
+
+        // mock tokenizer config values
+        tokenizer_config.bos_token = Some(TokenizerConfigToken::String("<s>".to_string()));
+        tokenizer_config.eos_token = Some(TokenizerConfigToken::String("</s>".to_string()));
+        tokenizer_config.chat_template = Some(
ChatTemplateVersions::Single("{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n".to_string())
+        );
+
+        let infer = Infer::new(
+            backend,
+            Validation::new(1, None, None, None, 1, 1, 1, 1, 1, false),
+            1,
+            tokenizer_config,
+            HubProcessorConfig::default(),
+        );
+        let response_format = None;
+        let tools = Some(vec![Tool {
+            r#type: "function".to_string(),
+            function: FunctionDefinition {
+                name: "get_current_weather".to_string(),
+                description: Some("Get the current weather".to_string()),
+                arguments: json!({
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA"
+                        },
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location."
+                        }
+                    },
+                    "required": ["location", "format"]
+                }),
+            },
+        }]);
+        let tool_prompt = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.";
+        let guideline = None;
+        let messages = vec![Message {
+            name: None,
+            role: "user".to_string(),
+            content: MessageContent::SingleText(
+                "What is the weather like in New York?".to_string(),
+            ),
+        }];
+
+        let result = prepare_chat_input(
+            &infer,
+            response_format,
+            tools,
+            ToolChoice(None),
+            tool_prompt,
+            guideline,
+            messages,
+        );
+
+        assert!(result.is_ok());
+        let (inputs, _grammar, using_tools) = result.unwrap();
+        assert_eq!(using_tools, true);
assert_eq!(inputs, "<s>[AVAILABLE_TOOLS] [{\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}, \"description\": \"Get the current weather\", \"name\": \"get_current_weather\"}}, {\"type\": \"function\", \"function\": {\"arguments\": {\"properties\":{\"error\":{\"description\":\"The error or issue to notify\",\"type\":\"string\"}},\"required\":[\"error\"],\"type\":\"object\"}, \"description\": \"Notify an error or issue\", \"name\": \"notify_error\"}}][/AVAILABLE_TOOLS][INST] What is the weather like in New York?\n---\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.[/INST]".to_string());
+    }
 }
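To illustrate the new untagged VertexInstance enum accepted by the Vertex-compatible endpoint, here is a hedged Python sketch of a request body mixing a raw generate instance and a chat instance. The field names follow the structs in this diff; the prompt text and parameter values are made up.

import json

# Hypothetical request body: the first instance deserializes as
# Generate(GenerateVertexInstance), the second as Chat(ChatRequest),
# which #[serde(untagged)] distinguishes by shape.
vertex_request = {
    "instances": [
        {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 32}},
        {
            "messages": [{"role": "user", "content": "What is Deep Learning?"}],
            "max_tokens": 32,
        },
    ]
}
print(json.dumps(vertex_request, indent=2))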
@@ -95,6 +95,7 @@ impl Validation {
     pub async fn tokenize(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
     ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
         // If we have a fast tokenizer
@@ -104,7 +105,11 @@ impl Validation {
             // Send request to the background validation task
             // Unwrap is safe here
             sender
-                .send(((inputs, truncate), response_sender, Span::current()))
+                .send((
+                    (inputs, add_special_tokens, truncate),
+                    response_sender,
+                    Span::current(),
+                ))
                 .unwrap();
 
             // Await on response channel
@@ -121,11 +126,15 @@ impl Validation {
     async fn validate_input(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
     ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
-        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+        if let Some((encoding, inputs)) = self
+            .tokenize(inputs.clone(), add_special_tokens, truncate)
+            .await?
+        {
             // Create response channel
             let input_length = if let Some(truncate) = truncate {
                 std::cmp::min(encoding.len(), truncate)
@@ -158,7 +167,8 @@ impl Validation {
                 ));
             }
 
-            let input_ids = encoding.get_ids()[..input_length].to_owned();
+            let ids = encoding.get_ids();
+            let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
 
             metrics::histogram!("tgi_request_input_length").record(input_length as f64);
             Ok((inputs, Some(input_ids), input_length, max_new_tokens))
@@ -324,7 +334,12 @@ impl Validation {
 
         // Validate inputs
         let (inputs, input_ids, input_length, max_new_tokens) = self
-            .validate_input(request.inputs, truncate, max_new_tokens)
+            .validate_input(
+                request.inputs,
+                request.add_special_tokens,
+                truncate,
+                max_new_tokens,
+            )
             .await?;
 
         // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@@ -401,6 +416,7 @@ impl Validation {
         Ok(ValidGenerateRequest {
             inputs,
             input_ids: input_ids.map(Arc::new),
+            add_special_tokens: request.add_special_tokens,
             decoder_input_details,
             input_length: input_length as u32,
             truncate: truncate.unwrap_or(self.max_input_length) as u32,
@@ -449,12 +465,15 @@ fn tokenizer_worker(
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
-    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
+    while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+        receiver.blocking_recv()
+    {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(
                     inputs,
                     truncate,
+                    add_special_tokens,
                     &tokenizer,
                     config.as_ref(),
                     preprocessor_config.as_ref(),
@@ -591,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
+    add_special_tokens: bool,
     tokenizer: &Tokenizer,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
@@ -628,14 +648,14 @@ fn prepare_input(
 
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, true)
+        .encode(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     Ok((encoding, input_chunks))
 }
 
 type TokenizerRequest = (
-    (String, Option<usize>),
+    (String, bool, Option<usize>),
     oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
     Span,
 );
@@ -720,6 +740,7 @@ pub struct ValidGenerateRequest {
     pub input_ids: Option<Arc<Vec<u32>>>,
     pub input_length: u32,
     pub truncate: u32,
+    pub add_special_tokens: bool,
     pub decoder_input_details: bool,
     pub parameters: ValidParameters,
     pub stopping_parameters: ValidStoppingParameters,
@@ -826,7 +847,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
             // Err(ValidationError::MaxNewTokens(1, 10)) => (),
@@ -861,7 +882,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
            .await
        {
            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -895,6 +916,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     best_of: Some(2),
                     do_sample: false,
@@ -934,6 +956,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(1.0),
                     max_new_tokens: Some(5),
@@ -949,6 +972,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(0.99),
                     max_new_tokens: Some(5),
@@ -964,6 +988,7 @@ mod tests {
         let valid_request = validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: None,
                     max_new_tokens: Some(5),
@@ -1002,6 +1027,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(5),
                     max_new_tokens: Some(5),
@@ -1017,6 +1043,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(4),
                     max_new_tokens: Some(5),
@@ -1029,6 +1056,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(0),
                     max_new_tokens: Some(5),
@@ -1041,6 +1069,7 @@ mod tests {
         let valid_request = validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: None,
                     max_new_tokens: Some(5),
@@ -1089,6 +1118,7 @@ mod tests {
         let chunks = match validation
             .tokenize(
                 format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
+                true,
                 None,
             )
             .await
@@ -1148,6 +1178,7 @@ mod tests {
                     "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
                     PIXEL_GIF, PIXEL_GIF
                 ),
+                true,
                 None,
             )
             .await
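The new input_ids slicing keeps the last input_length ids instead of the first ones. A small Python sketch of the equivalent indexing follows; the token ids are made up.

# Equivalent of `ids[ids.len().saturating_sub(input_length)..]`:
# keep the *last* input_length token ids rather than the first ones.
ids = [101, 7592, 2088, 102, 2023, 2003]
input_length = 4
input_ids = ids[max(len(ids) - input_length, 0):]
print(input_ids)  # [2088, 102, 2023, 2003]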
@@ -1,5 +1,5 @@
 [toolchain]
 # Released on: June 13, 2024
 # https://releases.rs/docs/1.79.0/
-channel = "1.79.0"
+channel = "1.80.0"
 components = ["rustfmt", "clippy"]

@@ -7,6 +7,7 @@ include Makefile-selective-scan
 include Makefile-lorax-punica
 include Makefile-fbgemm
 include Makefile-exllamav2
+include Makefile-flashinfer
 
 unit-tests:
 	pytest -s -vv -m "not private" tests

@@ -1,7 +1,9 @@
 fbgemm_commit := v0.8.0
 
 build-fbgemm:
-	git clone https://github.com/pytorch/FBGEMM.git fbgemm && \
+	@if [ ! -d "fbgemm" ]; then \
+		git clone https://github.com/pytorch/FBGEMM.git fbgemm; \
+	fi
 	cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \
 	git submodule update --init --recursive && \
 	cd fbgemm_gpu && \

@@ -0,0 +1,2 @@
+install-flashinfer:
+	pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4

@@ -3237,11 +3237,6 @@ files = [
     {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
     {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
     {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
-    {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
-    {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
-    {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
-    {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
-    {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
 ]
 
 [package.dependencies]

@@ -1,7 +1,10 @@
 import pytest
+import os
 from text_generation_server.pb import generate_pb2
 
+os.environ["USE_PREFIX_CACHING"] = "1"
+os.environ["ATTENTION"] = "flashinfer"
+
 
 @pytest.fixture
 def default_pb_parameters():

@@ -1,6 +1,54 @@
 import pytest
 from unittest.mock import Mock
-from text_generation_server.utils.adapter import get_attn_weights, get_mlp_weights
+from text_generation_server.utils.adapter import (
+    get_attn_weights,
+    get_mlp_weights,
+    parse_lora_adapters,
+    AdapterInfo,
+)
+
+
+def test_parse_lora_adapters_empty():
+    assert parse_lora_adapters(None) == []
+    assert parse_lora_adapters("") == []
+
+
+def test_parse_lora_adapters_single():
+    result = parse_lora_adapters("adapter1")
+    assert result == [AdapterInfo(id="adapter1", path=None, revision=None)]
+
+
+def test_parse_lora_adapters_with_path():
+    result = parse_lora_adapters("adapter1=path/to/adapter1")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision=None)
+    ]
+
+
+def test_parse_lora_adapters_with_path_and_revision():
+    result = parse_lora_adapters("adapter1=path/to/adapter1@main")
+    assert result == [
+        AdapterInfo(id="adapter1", path="path/to/adapter1", revision="main")
+    ]
+
+
+def test_parse_lora_adapters_multiple():
+    result = parse_lora_adapters(
+        "adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev"
+    )
+    assert result == [
+        AdapterInfo(id="adapter1", path=None, revision=None),
+        AdapterInfo(id="adapter2", path="path/to/adapter2", revision=None),
+        AdapterInfo(id="adapter3", path="path/to/adapter3", revision="dev"),
+    ]
+
+
+def test_parse_lora_adapters_invalid_format():
+    try:
+        parse_lora_adapters("adapter1,invalid=format=test,adapter3")
+        assert False, "Should have raised ValueError"
+    except ValueError as e:
+        assert str(e) == "Invalid LoRA adapter format: invalid=format=test"
 
 
 def test_get_attn_weights():
@@ -22,6 +70,10 @@ def test_get_attn_weights():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -115,6 +167,10 @@ def test_get_attn_weights_llama_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
@@ -155,6 +211,10 @@ def test_get_attn_weights_gemma_compatibility():
             "model.layers.2.self_attn.k_proj",
             mock_layer.self_attn.query_key_value,
         ),
+        (2, "qkv_proj"): (
+            "model.layers.2.self_attn.qkv_proj",
+            mock_layer.self_attn.query_key_value,
+        ),
         (2, "v_proj"): (
             "model.layers.2.self_attn.v_proj",
             mock_layer.self_attn.query_key_value,
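For context, a minimal Python sketch of the `id[=path[@revision]]` syntax these tests exercise. This is an illustrative reimplementation under assumptions, not the actual text_generation_server.utils.adapter code; the AdapterSpec name and parse_adapters helper are hypothetical.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class AdapterSpec:  # hypothetical stand-in for AdapterInfo
    id: str
    path: Optional[str]
    revision: Optional[str]

def parse_adapters(value: Optional[str]) -> List[AdapterSpec]:
    # "adapter1,adapter2=path/to/adapter2@main" -> list of specs
    if not value:
        return []
    specs = []
    for item in value.split(","):
        parts = item.split("=")
        if len(parts) == 1:
            specs.append(AdapterSpec(id=parts[0], path=None, revision=None))
        elif len(parts) == 2:
            path, _, revision = parts[1].partition("@")
            specs.append(AdapterSpec(id=parts[0], path=path, revision=revision or None))
        else:
            raise ValueError(f"Invalid LoRA adapter format: {item}")
    return specs

print(parse_adapters("adapter1,adapter2=path/to/adapter2@dev"))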
@@ -6,7 +6,12 @@ from .common import Seqlen
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
-    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+    from .cuda import (
+        attention,
+        paged_attention,
+        reshape_and_cache,
+        SUPPORTS_WINDOWING,
+    )
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "ipex":

@@ -9,26 +9,46 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
+        max_q: int
+        max_k: int
 
-        def __init__(self, input_lengths):
+        def __init__(
+            self,
+            input_lengths,
+            prefix_lengths,
+            cu_seqlen_q=None,
+            max_q=None,
+            max_k=None,
+        ):
             self.input_lengths = input_lengths
+            self.prefix_lengths = prefix_lengths
             device = self.input_lengths.device
             shape = self.input_lengths.shape
-            cu_seqlen_q = torch.arange(
-                shape[0] + 1,
-                device=device,
-                dtype=torch.int32,
-            )
+            if cu_seqlen_q is None:
+                cu_seqlen_q = torch.arange(
+                    shape[0] + 1,
+                    device=device,
+                    dtype=torch.int32,
+                )
+                max_q = 1
+            else:
+                assert max_q is not None
+                assert max_k is not None
             cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
 
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+            total = self.input_lengths + self.prefix_lengths
+            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
             self.cu_seqlen_k = cu_seqlen_k
+            self.max_q = max_q
+            self.max_k = max_k
 
         def clamp(self, max):
             # Flash decoding doesn't need to clamp
@@ -39,6 +59,11 @@ else:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
+        cu_seqlen_q: torch.Tensor
+        max_q: int
+        max_k: int
 
         def clamp(self, max):
+            raise NotImplementedError("Not implemented seqlen for paged")
             return Seqlen(torch.clamp(self.input_lengths, max=max))
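A small runnable sketch of how the updated Seqlen builds cu_seqlen_k once prefix lengths are tracked: the cumulative sum now runs over input_lengths + prefix_lengths. The tensor values below are made up.

import torch

# With prefix caching, each sequence's KV length is its new input tokens plus
# the cached prefix, so cu_seqlen_k cumulates the sum of both.
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
prefix_lengths = torch.tensor([8, 0, 4], dtype=torch.int32)

cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)  # decode: one query per sequence
cu_seqlen_k = torch.zeros(input_lengths.shape[-1] + 1, dtype=torch.int32)
torch.cumsum(input_lengths + prefix_lengths, -1, out=cu_seqlen_k[1:])

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
print(cu_seqlen_k.tolist())  # [0, 11, 16, 22]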
@ -76,7 +76,7 @@ def paged_attention(
|
||||||
# sequences or heads is large, we use V1 since there is enough work
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
# to parallelize.
|
# to parallelize.
|
||||||
if ATTENTION == "flashinfer":
|
if ATTENTION == "flashinfer":
|
||||||
from text_generation_server.layers.attention.flash_infer import decode_state
|
from text_generation_server.layers.attention.flashinfer import decode_state
|
||||||
|
|
||||||
return decode_state.get().forward(
|
return decode_state.get().forward(
|
||||||
query.contiguous(),
|
query.contiguous(),
|
||||||
|
@ -221,36 +221,37 @@ SUPPORTS_WINDOWING = V2
|
||||||
if ATTENTION == "flashinfer":
|
if ATTENTION == "flashinfer":
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
q: torch.Tensor,
|
||||||
k,
|
key_cache: torch.Tensor,
|
||||||
v,
|
value_cache: torch.Tensor,
|
||||||
cu_seqlens,
|
seqlen: Seqlen,
|
||||||
max_s,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
window_size_left=-1,
|
window_size_left=-1,
|
||||||
causal=True,
|
causal=True,
|
||||||
softcap=0.0,
|
softcap=0.0,
|
||||||
):
|
):
|
||||||
from text_generation_server.layers.attention.flash_infer import prefill_state
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
|
prefill_with_paged_kv_state,
|
||||||
|
)
|
||||||
|
|
||||||
return prefill_state.get().forward(
|
return prefill_with_paged_kv_state.get().forward(
|
||||||
q,
|
q.contiguous(),
|
||||||
k,
|
|
||||||
v,
|
|
||||||
causal=causal,
|
causal=causal,
|
||||||
window_left=window_size_left,
|
paged_kv_cache=(key_cache, value_cache),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
|
window_left=window_size_left,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif V2:
|
elif V2:
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
q,
|
||||||
k,
|
key_cache: torch.Tensor,
|
||||||
v,
|
value_cache: torch.Tensor,
|
||||||
cu_seqlens,
|
seqlen: Seqlen,
|
||||||
max_s,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
window_size_left=-1,
|
window_size_left=-1,
|
||||||
causal=True,
|
causal=True,
|
||||||
|
@ -261,17 +262,17 @@ elif V2:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
q,
|
q,
|
||||||
k,
|
key_cache,
|
||||||
v,
|
value_cache,
|
||||||
out,
|
out,
|
||||||
cu_seqlens,
|
seqlen.cu_seqlen_q,
|
||||||
cu_seqlens,
|
seqlen.cu_seqlen_k,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
block_tables,
|
||||||
None,
|
None,
|
||||||
None,
|
seqlen.max_q,
|
||||||
max_s,
|
seqlen.max_k,
|
||||||
max_s,
|
|
||||||
0.0,
|
0.0,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
False,
|
False,
|
||||||
|
@ -289,6 +290,8 @@ else:
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_s,
|
max_s,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
|
|
|
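(Illustrative aside, not part of the commit: the `Seqlen` object above bundles what used to be passed as loose `cu_seqlens`/`max_s` arguments. A minimal sketch, with made-up numbers, of how its cumulative-length field relates to per-request input lengths:)

```python
import torch

# Hypothetical example values; only the relationship between the fields matters here.
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlen_q = torch.nn.functional.pad(
    input_lengths.cumsum(0, dtype=torch.int32), (1, 0)
)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
max_q = int(input_lengths.max())  # 5, the bound handed to the attention kernels
```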
@@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
"prefill_state"
)

+prefill_with_paged_kv_state: ContextVar[
+flashinfer.BatchPrefillWithPagedKVCacheWrapper
+] = ContextVar("prefill_with_paged_kv_state")

decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)

@@ -24,6 +28,78 @@ def get_workspace(device):
return workspace


+def create_prefill_with_paged_kv_state(
+*,
+device: torch.device,
+):
+"""Create a prefill state that uses the KV cache."""
+workspace_buffer = get_workspace(device)
+return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+workspace_buffer, kv_layout="NHD", use_cuda_graph=False
+)


+@contextmanager
+def use_prefill_with_paged_kv_state(
+*,
+state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
+block_tables: torch.Tensor,
+cu_seqlens: torch.Tensor,
+input_lengths: torch.Tensor,
+num_heads: int,
+num_kv_heads: int,
+head_size: int,
+page_size: int,
+query_dtype: str = "float16",
+):
+"""
+Context manager to set the active flashinfer prefill state to the given
+`state` and parameters. This state will be used by all calls to the
+`attention` function while the context manager is active.
+"""

+indptr = torch.zeros(
+input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
+)
+# Round up to page size and then calculate the cumulative sum to get
+# the indices into the block table.
+torch.add(input_lengths, page_size - 1, out=indptr[1:])
+indptr[1:].div_(page_size, rounding_mode="floor")
+indptr[1:].cumsum_(-1)

+# Get the lengths of the last page in a block.
+if page_size == 1:
+last_page_len = torch.ones(
+input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+)
+else:
+last_page_len = torch.empty(
+input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+)
+torch.sub(input_lengths, 1, out=last_page_len)
+last_page_len.remainder_(page_size)
+last_page_len += 1

+token = prefill_with_paged_kv_state.set(state)
+try:
+state.begin_forward(
+qo_indptr=cu_seqlens,
+paged_kv_indptr=indptr,
+paged_kv_indices=block_tables,
+paged_kv_last_page_len=last_page_len,
+num_qo_heads=num_heads,
+num_kv_heads=num_kv_heads,
+head_dim=head_size,
+q_data_type=query_dtype,
+page_size=page_size,
+)
+yield
+finally:
+state.end_forward()
+if token is not None:
+prefill_with_paged_kv_state.reset(token)


def create_prefill_state(
*,
device: torch.device,
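(Illustrative aside, not part of the commit: a small worked example of the page rounding done in `use_prefill_with_paged_kv_state` above, with made-up lengths.)

```python
import torch

# With page_size=16, a 5-token request needs 1 KV page and a 17-token request needs 2,
# so the page indptr becomes [0, 1, 3] and the last pages hold 5 and 1 tokens.
page_size = 16
input_lengths = torch.tensor([5, 17], dtype=torch.int32)

pages = (input_lengths + page_size - 1) // page_size   # ceil-divide -> tensor([1, 2])
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
indptr[1:] = pages.cumsum(0)                           # indptr -> tensor([0, 1, 3])

last_page_len = (input_lengths - 1) % page_size + 1    # tensor([5, 1])
```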
@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False


def attention(
-q,
+q: torch.Tensor,
-k,
+key_cache: torch.Tensor,
-v,
+value_cache: torch.Tensor,
-cu_seqlens,
+seqlen: Seqlen,
-max_s,
+block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,

@@ -23,13 +23,13 @@ def attention(
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q,
-k,
+key_cache,
-v,
+value_cache,
out,
-cu_seqlens,
+seqlen.cu_seqlen_q,
-cu_seqlens,
+seqlen.cu_seqlen_q,
-max_s,
+seqlen.max_q,
-max_s,
+seqlen.max_q,
0.0,
softmax_scale,
False,

@@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
)

def forward(self, x):
+if not self.heads:
+return None
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits

@@ -45,12 +45,107 @@ class MLPSpeculatorLayerNorm(nn.Module):
return x


+INV_SQRT2 = 2**-0.5


+def simple_norm(x: torch.Tensor, eps=1e-06):
+xf = x
+xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
+x = xf.type_as(x)
+return x * INV_SQRT2


+class MLPSpeculatorModelTied(torch.nn.Module):
+def __init__(self, config, prefix, weights):
+super().__init__()
+self.config = config
+self.n_predict = get_speculate()
+self.hidden_size = config.hidden_size

+self.emb = TensorParallelEmbedding(f"{prefix}.emb.0", weights)
+self.proj0 = FastLinear.load(
+config,
+prefix=f"{prefix}.proj.0",
+weights=weights,
+bias=False,
+)
+self.proj1 = FastLinear.load(
+config,
+prefix=f"{prefix}.proj.1",
+weights=weights,
+bias=False,
+)
+self.head = FastLinear.load(config, f"{prefix}.head.0", weights, bias=False)
+self.ln = MLPSpeculatorLayerNorm(
+prefix=f"{prefix}.ln.0",
+config=config,
+weights=weights,
+)

+# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
+self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
+self.activation = nn.GELU()
+self.vsize = config.vocab_size
+self.inner_dim = config.speculator_config["inner_dim"]
+self.top_k_tokens_per_head = [1] * self.n_predict
+self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+self.inner_dim / 2
+)
+self.emb.weight *= self.emb_weight

+def forward(
+self,
+hidden_states: torch.Tensor,
+input_ids: torch.Tensor,
+):
+top_k_tokens_per_head = self.top_k_tokens_per_head

+# k indicates # of candidates
+# h indicates # of generated tokens
+state = hidden_states
+b = state.size(0)
+ind = input_ids.unsqueeze(0)
+all_probs = torch.empty(
+b, self.n_predict, self.vsize, device=state.device
+) # b k h v
+assert (
+len(top_k_tokens_per_head) == self.n_predict
+), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
+for i in range(self.n_predict):
+# Project and predict
+z = self.emb(ind)
+# z = z.mul(self.emb_weight) # b k d
+if i == 0:
+state = self.proj0(state) * self.state_weight + z
+else:
+state = self.proj1(state) * self.state_weight + z
+state = self.activation(self.ln(state)) # b k d
+probs = F.log_softmax(self.head(state), dim=-1) # b k v
+_probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'

+# Update candidate set with new predictions

+# Update distribution set with new logits
+all_probs[:, i] = probs.exp()

+# Update state, log_probs and ind for new predictions
+state = state.unsqueeze(2).expand(
+-1, -1, top_k_tokens_per_head[i], -1
+) # b k k' d
+state = state.reshape(-1, b, state.size(3)) # b kk' d
+ind = preds.view(-1, b) # b kk'

+speculative_logits = all_probs
+return speculative_logits


class MLPSpeculatorModel(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.config = config
self.n_predict = get_speculate()
self.hidden_size = config.hidden_size

self.emb = nn.ModuleList(
[
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)

@@ -84,13 +179,15 @@ class MLPSpeculatorModel(torch.nn.Module):
)

# Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
-self.state_weight = 0.5 ** (0.5 / self.n_predict)
+self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1
-self.emb_weight = math.sqrt(1 - self.state_weight**2)
self.activation = nn.GELU()
-# TODO
self.vsize = config.vocab_size
self.inner_dim = config.speculator_config["inner_dim"]
self.top_k_tokens_per_head = [1] * self.n_predict
+self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(
+self.inner_dim / 2
+)
+self.emb.weight *= self.emb_weight

def forward(
self,

@@ -113,7 +210,7 @@ class MLPSpeculatorModel(torch.nn.Module):
for i in range(self.n_predict):
# Project and predict
z = self.emb[i](ind)
-z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
+# z = z.mul(self.emb_weight) # b k d
state = self.proj[i](state) * self.state_weight + z
state = self.activation(self.ln[i](state)) # b k d
probs = F.log_softmax(self.head[i](state), dim=-1) # b k v

@@ -136,10 +233,11 @@ class MLPSpeculatorModel(torch.nn.Module):


class MLPSpeculatorHead(nn.Module):
-def __init__(self, lm_head, mlp_speculator):
+def __init__(self, lm_head, mlp_speculator, scale_input: bool):
super().__init__()
self.lm_head = lm_head
self.mlp_speculator = mlp_speculator
+self.scale_input = scale_input

def forward(
self, input: torch.Tensor

@@ -150,6 +248,8 @@ class MLPSpeculatorHead(nn.Module):
return logits, None

input_ids = logits.argmax(dim=-1)
+if self.scale_input:
+input = simple_norm(input)
speculative_logits = self.mlp_speculator(input, input_ids)
return logits, speculative_logits


@@ -171,6 +271,12 @@ class MLPSpeculatorHead(nn.Module):
)
routing[k] = filename

-mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+tie_weights = config.speculator_config.get("tie_weights", False)
+if tie_weights:
+mlp_speculator = MLPSpeculatorModelTied(config, "speculator", weights)
+else:
+mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+# This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator
+scale_input = config.speculator_config.get("scale_input", False)
lm_head = TensorParallelHead.load(config, prefix, weights)
-return MLPSpeculatorHead(lm_head, mlp_speculator)
+return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)
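(Illustrative aside, not part of the commit: a quick numeric check of the speculator scaling above, using invented values for `n_predict` and `inner_dim`.)

```python
import math

n_predict, inner_dim = 4, 4096                 # hypothetical example values
state_weight = 0.5 ** (0.5 / n_predict)        # ~0.917
emb_weight = math.sqrt(1 - state_weight**2) * math.sqrt(inner_dim / 2)

# After n_predict projections the initial state contributes
# state_weight ** (2 * n_predict) == 0.5 of the variance, i.e. ~50%,
# which is what the "50% of state magnitude" comment above refers to.
assert abs(state_weight ** (2 * n_predict) - 0.5) < 1e-9
```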
@@ -458,6 +458,11 @@ def get_model(
revision=mlp_revision,
filename=filename,
)
+speculator_dir_path = Path(mlp_speculator_config).parent
+# if these are downloaded, they get converted to safetensors
+filenames.extend(
+[p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
+)
speculator = {
"path": Path(mlp_speculator_config).parent,
"model_paths": filenames,

@@ -497,15 +502,14 @@ def get_model(
else -1
)

-should_use_sliding_window = (
-sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
+use_sliding_window = sliding_window is not None and sliding_window != -1
+needs_sliding_window = (
+max_input_tokens is not None and max_input_tokens > sliding_window
)

-if should_use_sliding_window:
-if max_input_tokens is not None and max_input_tokens > sliding_window:
-raise ValueError(
-f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
-)
+if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
+raise ValueError(
+f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
+)

if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION:

@@ -1255,6 +1259,7 @@ def get_model_with_lora_adapters(
"gate_proj",
"up_proj",
"down_proj",
+"qkv_proj",
]

for layer_name in adapter_layers:

@@ -1282,7 +1287,7 @@ def get_model_with_lora_adapters(

if len(unused_weight_names) > 0:
logger.warning(
-f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
+f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
)

if adapter_tokenizer is not None:
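(Illustrative aside, not part of the commit: the reworked guard above only rejects a model when windowing is actually required and the backend cannot provide it. A minimal standalone sketch with invented values:)

```python
# Assumed example values; in TGI these come from the model config and the attention backend.
SUPPORTS_WINDOWING = False
sliding_window = 4096
max_input_tokens = 8192

use_sliding_window = sliding_window is not None and sliding_window != -1
needs_sliding_window = (
    max_input_tokens is not None and max_input_tokens > sliding_window
)
if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
    raise ValueError("This backend does not support the sliding window attention the model requires.")
```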
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
+Seqlen,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (

@@ -264,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -296,10 +297,10 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-key,
+kv_cache[0] if SYSTEM != "ipex" else key,
-value,
+kv_cache[1] if SYSTEM != "ipex" else value,
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
)
# Decode

@@ -311,7 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
)

@@ -386,7 +387,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -400,7 +401,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -452,7 +453,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -475,7 +476,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -516,7 +517,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -529,7 +530,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
+Seqlen,
)
from text_generation_server.layers import (
FastLinear,

@@ -309,7 +310,7 @@ class DbrxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -335,10 +336,10 @@ class DbrxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-torch.select(kv, dim=1, index=0),
+kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-torch.select(kv, dim=1, index=1),
+kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
)
# Decode

@@ -350,7 +351,7 @@ class DbrxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
)

@@ -387,7 +388,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)

@@ -401,7 +402,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -620,7 +621,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
# Self Attention

@@ -633,7 +634,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -677,7 +678,7 @@ class DbrxModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -699,7 +700,7 @@ class DbrxModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -732,7 +733,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -745,7 +746,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -29,8 +29,8 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
+Seqlen,
)
-from text_generation_server.layers.attention.common import Seqlen
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM

@@ -298,7 +298,7 @@ class DeepseekV2Attention(torch.nn.Module):
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: Seqlen,
+seqlen: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:

@@ -363,10 +363,10 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-key,
+kv_cache[0] if SYSTEM != "ipex" else key,
-value,
+kv_cache[1] if SYSTEM != "ipex" else value,
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
)
# Decode

@@ -378,7 +378,7 @@ class DeepseekV2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
)

@@ -664,7 +664,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: Seqlen,
+seqlen: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)

@@ -678,7 +678,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -727,7 +727,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)

@@ -749,7 +749,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -779,7 +779,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -792,7 +792,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -25,11 +25,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
+Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -236,10 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-torch.select(kv, dim=1, index=0),
+kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-torch.select(kv, dim=1, index=1),
+kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,

@@ -254,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
softcap=self.softcap,
)

@@ -341,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -355,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -406,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds

@@ -428,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -475,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -489,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -25,11 +25,12 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
+Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)

@@ -230,10 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-torch.select(kv, dim=1, index=0),
+kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-torch.select(kv, dim=1, index=1),
+kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
causal=self.causal,
)

@@ -246,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
)

@@ -318,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -332,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -380,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = inputs_embeds

@@ -402,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -447,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,

@@ -461,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)
if lm_head_indices is not None:
@@ -24,11 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
+Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(

@@ -230,10 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-key,
+kv_cache[0] if SYSTEM != "ipex" else key,
-value,
+kv_cache[1] if SYSTEM != "ipex" else value,
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
)
# Decode

@@ -245,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
)

@@ -314,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
residual = hidden_states

@@ -327,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -380,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],

@@ -396,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -433,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,

@@ -449,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
@@ -24,11 +24,12 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
+Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -43,7 +44,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
-from text_generation_server.utils.import_utils import SYSTEM


def load_attention(config, prefix: str, weights):

@@ -167,7 +167,7 @@ class FlashGPTJAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(

@@ -192,10 +192,10 @@ class FlashGPTJAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-key,
+kv_cache[0] if SYSTEM != "ipex" else key,
-value,
+kv_cache[1] if SYSTEM != "ipex" else value,
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
)
# Decode

@@ -207,7 +207,7 @@ class FlashGPTJAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
)

@@ -268,7 +268,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
):
hidden_states, residual = self.input_layernorm(hidden_states, residual)

@@ -281,7 +281,7 @@ class FlashGPTJLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -328,7 +328,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:

@@ -351,7 +351,7 @@ class FlashGPTJModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
)

@@ -382,7 +382,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,

@@ -395,7 +395,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
prefill_cache_indices=prefill_cache_indices,
)
@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
+Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,

@@ -65,15 +66,15 @@ def load_attention(config, prefix: str, weights, layer_id):
prefixes = None

if config.model_type == "phi3":
-prefix = f"{prefix}.qkv_proj"
base_layer = TensorParallelColumnLinear.load_qkv(
config,
-prefix=prefix,
+prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
+prefixes = ["qkv_proj"]
elif config.model_type == "baichuan":
prefix = f"{prefix}.W_pack"
base_layer = TensorParallelColumnLinear.load_qkv(

@@ -84,6 +85,7 @@ def load_attention(config, prefix: str, weights, layer_id):
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
+prefixes = [prefix]
else:
prefixes = ["q_proj", "k_proj", "v_proj"]
sizes = [

@@ -194,7 +196,7 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
adapter_data,
):

@@ -218,10 +220,10 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
-torch.select(kv, dim=1, index=0),
+kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-torch.select(kv, dim=1, index=1),
+kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
-cu_seqlen_prefill,
+seqlen,
-max_s,
+block_tables,
self.softmax_scale,
)
# Decode

@@ -233,7 +235,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
-input_lengths,
+seqlen,
max_s,
)

@@ -373,7 +375,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
adapter_data,
):

@@ -388,7 +390,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
adapter_data,
)

@@ -477,7 +479,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],

@@ -502,7 +504,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
adapter_data,
)

@@ -546,7 +548,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,

@@ -560,7 +562,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
-input_lengths,
+seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
+Seqlen,
 )
 from text_generation_server.layers import (
 TensorParallelRowLinear,
@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 adapter_data,
@@ -217,10 +218,10 @@ class MistralAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-torch.select(kv, dim=1, index=0),
-torch.select(kv, dim=1, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+seqlen,
+block_tables,
 self.softmax_scale,
 window_size_left=self.max_past,
 )
@@ -233,7 +234,7 @@ class MistralAttention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -354,7 +355,7 @@ class MistralLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 adapter_data,
@@ -370,7 +371,7 @@ class MistralLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 adapter_data,
@@ -422,7 +423,7 @@ class MistralModel(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 true_max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
@@ -446,7 +447,7 @@ class MistralModel(torch.nn.Module):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 adapter_data,
@@ -497,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -510,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
 elif self.max_past is not None:
 # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
 # kernel requires the true values
-input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+seqlen = seqlen.clamp(max=self.max_past_tensor)

 inputs_embeds = self.embed_tokens(input_ids)
 hidden_states = self.model(
@@ -520,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 true_max_s,
 prefill_cache_indices,

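The hunks above replace the raw `input_lengths` tensor with a `seqlen` value that the causal-LM wrapper can still `.clamp(...)` for sliding-window decode. As a rough sketch only, assuming a simplified shape (the real `Seqlen` in `text_generation_server.layers.attention` likely carries more fields, such as cumulative sequence lengths for the prefill kernels):

    from dataclasses import dataclass
    import torch

    @dataclass
    class Seqlen:
        # assumption: a thin wrapper around per-request input lengths
        input_lengths: torch.Tensor

        def clamp(self, max) -> "Seqlen":
            # mirrors the decode-mode clamping shown at the call sites above
            return Seqlen(input_lengths=self.input_lengths.clamp(max=max))
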
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
+Seqlen,
 )
 from text_generation_server.layers import (
 FastLinear,
@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 ):
@@ -274,10 +275,10 @@ class MixtralAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-torch.select(kv, dim=1, index=0),
-torch.select(kv, dim=1, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+seqlen,
+block_tables,
 self.softmax_scale,
 window_size_left=self.max_past,
 )
@@ -290,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -496,7 +497,7 @@ class MixtralLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 ):
@@ -511,7 +512,7 @@ class MixtralLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 )
@@ -566,7 +567,7 @@ class MixtralModel(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 true_max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
@@ -590,7 +591,7 @@ class MixtralModel(torch.nn.Module):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 )
@@ -625,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -638,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
 elif self.max_past is not None:
 # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
 # kernel requires the true values
-input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+seqlen = seqlen.clamp(max=self.max_past_tensor)

 hidden_states = self.model(
 input_ids,
@@ -647,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 true_max_s,
 prefill_cache_indices,

@@ -26,11 +26,12 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
+Seqlen,
 )
 from text_generation_server.layers import (
 TensorParallelRowLinear,
@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 qkv = self.query_key_value(hidden_states)
@@ -171,10 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 qkv[:, 0],
-qkv[:, 1],
-qkv[:, 2],
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+seqlen,
+block_tables,
 self.softmax_scale,
 )
 # Decode
@@ -186,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -256,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 if self.use_parallel_residual:
@@ -270,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -294,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -348,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 ) -> torch.Tensor:
 hidden_states = self.embed_in(input_ids)
@@ -370,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -402,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -415,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
 if lm_head_indices is not None:

@@ -19,6 +19,7 @@ from torch import nn
 from typing import Optional, List, Tuple

 from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.custom_modeling.vlm import (
 load_text_model,
 load_vision_model,
@@ -34,6 +35,11 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 config=config.vision_config,
 weights=weights,
 )
+self.post_vision_tower_layernorm = nn.LayerNorm.load(
+prefix="vision_tower.vision_model.post_layernorm",
+weights=weights,
+eps=config.vision_config.layer_norm_eps,
+)

 self.multi_modal_projector = TensorParallelColumnLinear.load(
 config,
@@ -65,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor] = None,
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -84,7 +90,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 if pixel_values is not None:
 pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
 image_outputs = self.vision_tower(pixel_values)
-image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
+last_hidden_state = self.post_vision_tower_layernorm(
+image_outputs.last_hidden_state
+)
+image_features = self.multi_modal_projector(last_hidden_state)

 # mask where image or padding tokens
 mask = input_ids == self.config.image_token_index
@@ -99,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
 kv_cache=kv_cache,
 block_tables=block_tables,
 slots=slots,
-input_lengths=input_lengths,
+seqlen=seqlen,
 max_s=max_s,
 )

@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
+Seqlen,
 )
 from text_generation_server.layers import (
 TensorParallelRowLinear,
@@ -24,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
 PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):
@@ -159,7 +161,7 @@ class FlashPhiAttention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 # Compute query, key, value and split
@@ -192,10 +194,10 @@ class FlashPhiAttention(torch.nn.Module):
 if cu_seqlen_prefill is not None:
 attn_output = attention(
 query,
-torch.select(kv, dim=1, index=0),
-torch.select(kv, dim=1, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+seqlen,
+block_tables,
 self.softmax_scale,
 )
 # Decode
@@ -207,7 +209,7 @@ class FlashPhiAttention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -274,7 +276,7 @@ class FlashPhiLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -287,7 +289,7 @@ class FlashPhiLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -339,7 +341,7 @@ class FlashPhiModel(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 ) -> torch.Tensor:
 hidden_states = self.embed_tokens(input_ids)
@@ -361,7 +363,7 @@ class FlashPhiModel(torch.nn.Module):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -394,7 +396,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -407,7 +409,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
 if lm_head_indices is not None:

@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
+Seqlen,
 )
 from text_generation_server.layers import (
 TensorParallelRowLinear,
@@ -20,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
 FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):
@@ -104,7 +106,7 @@ class Qwen2Attention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 ):
@@ -135,10 +137,10 @@ class Qwen2Attention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-torch.select(kv, dim=1, index=0),
-torch.select(kv, dim=1, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+seqlen,
+block_tables,
 self.softmax_scale,
 window_size_left=self.max_past,
 )
@@ -151,7 +153,7 @@ class Qwen2Attention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -223,7 +225,7 @@ class Qwen2Layer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 ):
@@ -238,7 +240,7 @@ class Qwen2Layer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 )
@@ -294,7 +296,7 @@ class Qwen2Model(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 true_max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
@@ -318,7 +320,7 @@ class Qwen2Model(torch.nn.Module):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 )
@@ -359,7 +361,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor] = None,
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -372,7 +374,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
 elif self.max_past is not None:
 # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
 # kernel requires the true values
-input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+seqlen = seqlen.clamp(max=self.max_past_tensor)

 hidden_states = self.model(
 input_ids,
@@ -381,7 +383,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 true_max_s,
 prefill_cache_indices,

@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
 SpeculativeHead,
 TensorParallelColumnLinear,
@@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
 attention,
 paged_attention,
 reshape_and_cache,
+Seqlen,
 )
@@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 qkv = self.query_key_value(hidden_states)
@@ -206,10 +207,10 @@ class FlashRWAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-torch.select(kv, dim=1, index=0),
-torch.select(kv, dim=1, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+seqlen,
+block_tables,
 self.softmax_scale,
 )
 # Decode
@@ -221,7 +222,7 @@ class FlashRWAttention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -294,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 qkv = self.query_key_value(hidden_states)
@@ -324,10 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-torch.select(kv, dim=2, index=0),
-torch.select(kv, dim=2, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+seqlen,
+block_tables,
 self.softmax_scale,
 )
 # Decode
@@ -339,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -425,7 +426,7 @@ class FlashRWLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 if self.parallel_attn:
@@ -439,7 +440,7 @@ class FlashRWLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -461,7 +462,7 @@ class FlashRWLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -548,7 +549,7 @@ class FlashRWLargeLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 # Layer norm.
@@ -563,7 +564,7 @@ class FlashRWLargeLayer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -624,7 +625,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 ) -> torch.Tensor:
 hidden_states = self.word_embeddings(input_ids)
@@ -646,7 +647,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -676,7 +677,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -689,7 +690,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
 if lm_head_indices is not None:

@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
+Seqlen,
 )
 from text_generation_server.layers import (
 TensorParallelRowLinear,
@@ -21,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
 FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(
@@ -268,7 +270,7 @@ class FlashMQAttention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 qkv = self.c_attn(hidden_states)
@@ -291,10 +293,10 @@ class FlashMQAttention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-torch.select(key_value, dim=1, index=0),
-torch.select(key_value, dim=1, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+seqlen,
+block_tables,
 self.softmax_scale,
 )
 # Decode
@@ -306,7 +308,7 @@ class FlashMQAttention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -371,7 +373,7 @@ class Block(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 ):
 hidden_states, residual = self.ln_1(hidden_states, residual)
@@ -381,7 +383,7 @@ class Block(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -435,7 +437,7 @@ class FlashSantacoderModel(nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 ) -> torch.Tensor:
 hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -452,7 +454,7 @@ class FlashSantacoderModel(nn.Module):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -484,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -497,7 +499,7 @@ class FlashSantacoderForCausalLM(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 )
 if lm_head_indices is not None:

@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
 paged_attention,
 attention,
 reshape_and_cache,
+Seqlen,
 )
 from text_generation_server.layers import (
 TensorParallelRowLinear,
@@ -46,6 +47,7 @@ from text_generation_server.layers.rotary import (
 PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):
@@ -209,7 +211,7 @@ class Starcoder2Attention(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 ):
@@ -240,10 +242,10 @@ class Starcoder2Attention(torch.nn.Module):
 # flash attention
 attn_output = attention(
 query,
-torch.select(kv, dim=1, index=0),
-torch.select(kv, dim=1, index=1),
-cu_seqlen_prefill,
-max_s,
+kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+seqlen,
+block_tables,
 self.softmax_scale,
 window_size_left=self.max_past,
 )
@@ -256,7 +258,7 @@ class Starcoder2Attention(torch.nn.Module):
 self.kv_head_mapping,
 self.softmax_scale,
 block_tables,
-input_lengths,
+seqlen,
 max_s,
 )
@@ -379,7 +381,7 @@ class Starcoder2Layer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 ):
@@ -394,7 +396,7 @@ class Starcoder2Layer(nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 )
@@ -447,7 +449,7 @@ class Starcoder2Model(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 true_max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
@@ -471,7 +473,7 @@ class Starcoder2Model(torch.nn.Module):
 kv_cache[i],
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 prefill_cache_indices,
 )
@@ -519,7 +521,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -532,7 +534,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
 elif self.max_past is not None:
 # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
 # kernel requires the true values
-input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+seqlen = seqlen.clamp(max=self.max_past_tensor)

 hidden_states = self.model(
 input_ids,
@@ -541,7 +543,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
 kv_cache,
 block_tables,
 slots,
-input_lengths,
+seqlen,
 max_s,
 true_max_s,
 prefill_cache_indices,

@@ -25,6 +25,7 @@ from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
 load_text_model,
 )
+from text_generation_server.layers.attention import Seqlen
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

 from text_generation_server.layers import (
@@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
 kv_cache=kv_cache,
 block_tables=block_tables,
 slots=slots,
-input_lengths=input_lengths,
+seqlen=seqlen,
 max_s=max_s,
 true_max_s=max_s,
 prefill_cache_indices=None,

@@ -23,6 +23,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution

+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.custom_modeling.vlm import (
 load_text_model,
 load_vision_model,
@@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
 kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
 block_tables: torch.Tensor,
 slots: torch.Tensor,
-input_lengths: torch.Tensor,
+seqlen: Seqlen,
 max_s: int,
 prefill_cache_indices: Optional[torch.Tensor],
 lm_head_indices: Optional[torch.Tensor] = None,
@@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
 kv_cache=kv_cache,
 block_tables=block_tables,
 slots=slots,
-input_lengths=input_lengths,
+seqlen=seqlen,
 max_s=max_s,
 true_max_s=max_s,
 prefill_cache_indices=None,

@@ -364,7 +364,6 @@ class SiglipEncoder(nn.Module):
 inputs_embeds,
 attention_mask: Optional[torch.Tensor] = None,
 ):
-
 hidden_states = inputs_embeds
 for idx, encoder_layer in enumerate(self.layers):
 hidden_states, _ = encoder_layer(
@@ -386,20 +385,11 @@ class SiglipVisionTransformer(nn.Module):
 self.encoder = SiglipEncoder(
 prefix=f"{prefix}.encoder", config=config, weights=weights
 )
-self.post_layernorm = nn.LayerNorm.load(
-prefix=f"{prefix}.post_layernorm",
-weights=weights,
-eps=config.layer_norm_eps,
-)

 def forward(
 self,
 pixel_values: Optional[torch.FloatTensor] = None,
 ):
-r"""
-Returns:
-
-"""
 if pixel_values is None:
 raise ValueError("You have to specify pixel_values")
@@ -412,10 +402,9 @@ class SiglipVisionTransformer(nn.Module):
 inputs_embeds=hidden_states,
 )
 last_hidden_state = encoder_outputs
-post_last_hidden_state = self.post_layernorm(last_hidden_state)

 return BaseModelOutputWithPooling(
-last_hidden_state=post_last_hidden_state,
+last_hidden_state=last_hidden_state,
 # pooler_output=pooled_output,
 # hidden_states=encoder_outputs,
 )

@ -43,6 +43,7 @@ from text_generation_server.models.globals import (
|
||||||
ATTENTION,
|
ATTENTION,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
CUDA_GRAPHS,
|
CUDA_GRAPHS,
|
||||||
|
TGI_WIGGLE_ROOM,
|
||||||
get_adapter_to_index,
|
get_adapter_to_index,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
block_tables_tensor: torch.Tensor
|
block_tables_tensor: torch.Tensor
|
||||||
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
|
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
|
||||||
slots: torch.Tensor
|
slots: torch.Tensor
|
||||||
|
# size [b], containing the number of blocks that can be retrieved from the cache
|
||||||
|
prefix_lens: List[int]
|
||||||
|
prefix_lens_tensor: torch.Tensor
|
||||||
|
|
||||||
max_seqlen: int
|
max_seqlen: int
|
||||||
|
|
||||||
|
@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
prefill_next_token_indices: Optional[torch.tensor]
|
prefill_next_token_indices: Optional[torch.tensor]
|
||||||
prefill_cu_outlens: Optional[List[int]]
|
prefill_cu_outlens: Optional[List[int]]
|
||||||
|
|
||||||
|
# Prefixes
|
||||||
|
prefix_ids: List[List[int]]
|
||||||
|
|
||||||
# All tokens
|
# All tokens
|
||||||
all_input_ids: List[List[int]]
|
all_input_ids: List[List[int]]
|
||||||
all_input_ids_tensor: torch.Tensor
|
all_input_ids_tensor: torch.Tensor
|
||||||
|
@ -182,16 +189,21 @@ class FlashCausalLMBatch(Batch):
|
||||||
def batch_tokenized_inputs(
|
def batch_tokenized_inputs(
|
||||||
cls, requests: Iterable[generate_pb2.Request], tokenizer
|
cls, requests: Iterable[generate_pb2.Request], tokenizer
|
||||||
):
|
):
|
||||||
batch_inputs = []
|
max_length = 0
|
||||||
max_truncation = 0
|
all_input_ids = []
|
||||||
|
batch_size = 0
|
||||||
for r in requests:
|
for r in requests:
|
||||||
batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
|
batch_size += 1
|
||||||
max_truncation = max(max_truncation, r.truncate)
|
inputs = concat_text_chunks(r.input_chunks.chunks)
|
||||||
|
input_ids = tokenizer(
|
||||||
batch_tokenized_inputs = tokenizer(
|
inputs,
|
||||||
batch_inputs, truncation=True, max_length=max_truncation
|
truncation=True,
|
||||||
)["input_ids"]
|
max_length=r.truncate,
|
||||||
return batch_tokenized_inputs
|
add_special_tokens=r.add_special_tokens,
|
||||||
|
)["input_ids"]
|
||||||
|
max_length = max(max_length, len(input_ids))
|
||||||
|
all_input_ids.append(input_ids)
|
||||||
|
return all_input_ids
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tokenized(
|
def from_tokenized(
|
||||||
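The hunk above moves tokenization from one batched call to one call per request, since `truncate` and `add_special_tokens` can now differ between requests. A small usage sketch with the standard `transformers` tokenizer API; `FakeRequest` is a hypothetical stand-in for the gRPC request objects used in the server:

    from transformers import AutoTokenizer

    class FakeRequest:
        # assumption: only the fields the loop above actually reads
        def __init__(self, text, truncate, add_special_tokens=True):
            self.text = text
            self.truncate = truncate
            self.add_special_tokens = add_special_tokens

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    requests = [
        FakeRequest("hello world", truncate=8),
        FakeRequest("bonjour", truncate=4, add_special_tokens=False),
    ]

    all_input_ids = []
    for r in requests:
        # each request carries its own truncation budget and special-token policy
        input_ids = tokenizer(
            r.text,
            truncation=True,
            max_length=r.truncate,
            add_special_tokens=r.add_special_tokens,
        )["input_ids"]
        all_input_ids.append(input_ids)
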
|
@ -213,6 +225,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
prefix_offsets = []
|
prefix_offsets = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
|
prefix_ids = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
all_prefill_logprobs = True
|
all_prefill_logprobs = True
|
||||||
|
@ -230,7 +243,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
cumulative_length = 0
|
cumulative_length = 0
|
||||||
cumulative_max_length = 0
|
cumulative_slot_tokens = 0
|
||||||
prefill_out_cumulative_length = 0
|
prefill_out_cumulative_length = 0
|
||||||
|
|
||||||
num_blocks = 0
|
num_blocks = 0
|
||||||
|
@ -240,6 +253,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
block_tables = []
|
block_tables = []
|
||||||
slots = []
|
slots = []
|
||||||
|
prefix_lens = []
|
||||||
|
|
||||||
# Parse batch
|
# Parse batch
|
||||||
for i, (r, tokenized_input) in enumerate(
|
for i, (r, tokenized_input) in enumerate(
|
||||||
|
@ -248,12 +262,18 @@ class FlashCausalLMBatch(Batch):
|
||||||
# request id -> idx in list mapping
|
# request id -> idx in list mapping
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
|
|
||||||
tokenized_input = tokenized_input[-r.truncate :]
|
orig_input_length = len(tokenized_input)
|
||||||
if (
|
|
||||||
tokenized_input[0] == tokenizer.bos_token_id
|
prefix_len = r.prefix_len
|
||||||
and tokenized_input[1] == tokenizer.bos_token_id
|
assert (
|
||||||
):
|
prefix_len <= orig_input_length
|
||||||
tokenized_input = tokenized_input[1:]
|
), f"Prefix {prefix_len} vs input {orig_input_length}"
|
||||||
|
if prefix_len == orig_input_length:
|
||||||
|
assert prefix_len > 0
|
||||||
|
prefix_len -= 1
|
||||||
|
|
||||||
|
prefix_ids.append(tokenized_input[:prefix_len])
|
||||||
|
tokenized_input = tokenized_input[prefix_len:]
|
||||||
|
|
||||||
input_length = len(tokenized_input)
|
input_length = len(tokenized_input)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
@ -264,7 +284,9 @@ class FlashCausalLMBatch(Batch):
|
||||||
all_input_ids.append(tokenized_input)
|
all_input_ids.append(tokenized_input)
|
||||||
|
|
||||||
# Position ids
|
# Position ids
|
||||||
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
request_position_ids = torch.arange(
|
||||||
|
prefix_len, orig_input_length, dtype=torch.int32
|
||||||
|
)
|
||||||
position_ids.append(request_position_ids)
|
position_ids.append(request_position_ids)
|
||||||
|
|
||||||
# Add cumulative lengths of all previous inputs
|
# Add cumulative lengths of all previous inputs
|
||||||
|
@ -288,11 +310,17 @@ class FlashCausalLMBatch(Batch):
|
||||||
# Remove one as the first token des not have a past
|
# Remove one as the first token des not have a past
|
||||||
speculative_length = get_speculate()
|
speculative_length = get_speculate()
|
||||||
speculative_length = 0 if speculative_length is None else speculative_length
|
speculative_length = 0 if speculative_length is None else speculative_length
|
||||||
total_tokens = input_length + max_new_tokens - 1 + speculative_length
|
|
||||||
|
# Tokens that need to be mapped to blocks.
|
||||||
|
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
|
||||||
|
|
||||||
|
# Tokens that need to be mapped to slots. We don't need slots for the
|
||||||
|
# cached prefix (if present).
|
||||||
|
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
|
||||||
|
|
||||||
# blocks and slots can be empty (for example in warmup)
|
# blocks and slots can be empty (for example in warmup)
|
||||||
if not r.blocks:
|
if not r.blocks:
|
||||||
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
|
needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
|
||||||
request_blocks = [
|
request_blocks = [
|
||||||
b for b in range(num_blocks, num_blocks + needed_blocks)
|
b for b in range(num_blocks, num_blocks + needed_blocks)
|
||||||
]
|
]
|
||||||
|
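The hunk above splits the old `total_tokens` into `block_tokens` (blocks must cover the whole sequence, cached prefix included) and `slot_tokens` (slots are only needed for tokens that still have to be computed). A short sketch of that arithmetic, assuming the same `BLOCK_SIZE` semantics; the function name is illustrative:

    import math

    BLOCK_SIZE = 16  # assumption: paged-attention block size

    def block_and_slot_tokens(prefix_len, input_length, max_new_tokens, speculative_length=0):
        # blocks map the whole sequence, cached prefix included
        orig_input_length = prefix_len + input_length
        block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
        # slots are only needed for tokens that are not already cached
        slot_tokens = input_length + max_new_tokens - 1 + speculative_length
        needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
        return needed_blocks, slot_tokens
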
@ -303,16 +331,20 @@ class FlashCausalLMBatch(Batch):
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
request_blocks = r.blocks
|
request_blocks = r.blocks
|
||||||
request_slots = r.slots
|
request_slots = r.slots[
|
||||||
|
prefix_len: #: orig_input_length + max_new_tokens + speculative_length
|
||||||
|
]
|
||||||
|
|
||||||
block_tables.append(request_blocks)
|
block_tables.append(request_blocks)
|
||||||
slots.extend(request_slots[:total_tokens])
|
|
||||||
|
slots.extend(request_slots)
|
||||||
|
prefix_lens.append(prefix_len)
|
||||||
num_blocks += len(request_blocks)
|
num_blocks += len(request_blocks)
|
||||||
start_slots.append(cumulative_max_length)
|
start_slots.append(cumulative_slot_tokens)
|
||||||
|
|
||||||
request_slot_indices = torch.arange(
|
request_slot_indices = torch.arange(
|
||||||
cumulative_max_length,
|
cumulative_slot_tokens,
|
||||||
cumulative_max_length + input_length,
|
cumulative_slot_tokens + input_length,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
slot_indices.append(request_slot_indices)
|
slot_indices.append(request_slot_indices)
|
||||||
|
@ -348,7 +380,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
# Update
|
# Update
|
||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
cumulative_max_length += total_tokens
|
cumulative_slot_tokens += slot_tokens
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
max_blocks = max(max_blocks, len(request_blocks))
|
max_blocks = max(max_blocks, len(request_blocks))
|
||||||
max_length = max(
|
max_length = max(
|
||||||
|
@ -425,12 +457,14 @@ class FlashCausalLMBatch(Batch):
|
||||||
)
|
)
|
||||||
|
|
||||||
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
block_tables_tensor = torch.zeros(
|
block_tables_tensor = torch.zeros(
|
||||||
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
||||||
)
|
)
|
||||||
for i, request_blocks in enumerate(block_tables):
|
for i, request_blocks in enumerate(block_tables):
|
||||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
||||||
block_tables_tensor = block_tables_tensor.to(device)
|
block_tables_tensor = block_tables_tensor.to(device)
|
||||||
|
prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=pb.id,
|
batch_id=pb.id,
|
||||||
|
@ -445,6 +479,8 @@ class FlashCausalLMBatch(Batch):
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
block_tables_tensor=block_tables_tensor,
|
block_tables_tensor=block_tables_tensor,
|
||||||
slots=slots,
|
slots=slots,
|
||||||
|
prefix_lens=prefix_lens,
|
||||||
|
prefix_lens_tensor=prefix_lens_tensor,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
prefill_head_indices=prefill_head_indices,
|
prefill_head_indices=prefill_head_indices,
|
||||||
prefill_next_token_indices=prefill_next_token_indices,
|
prefill_next_token_indices=prefill_next_token_indices,
|
||||||
|
@ -455,6 +491,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
read_offsets=read_offsets,
|
read_offsets=read_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=all_input_ids_tensor,
|
||||||
|
prefix_ids=prefix_ids,
|
||||||
next_token_chooser=next_token_chooser,
|
next_token_chooser=next_token_chooser,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
@@ -510,8 +547,10 @@ class FlashCausalLMBatch(Batch):
         start_slots = []
         block_tables = []
         all_input_ids = []
+        prefix_ids = []

         input_lengths = []
+        prefix_lens = []
         prefix_offsets = []
         read_offsets = []

@@ -533,11 +572,14 @@ class FlashCausalLMBatch(Batch):

             # Get length
             request_input_length = self.input_lengths[idx]
+            prefix_len = self.prefix_lens[idx]
             max_seqlen = max(max_seqlen, request_input_length)

             all_input_ids.append(self.all_input_ids[idx])
+            prefix_ids.append(self.prefix_ids[idx])

             input_lengths.append(request_input_length)
+            prefix_lens.append(prefix_len)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])

@@ -582,6 +624,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        prefix_lens_tensor = self.prefix_lens_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
@@ -617,10 +660,13 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens=None,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -681,6 +727,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
+        prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
@@ -698,7 +745,9 @@ class FlashCausalLMBatch(Batch):

         start_slots = []
         block_tables = []
+        prefix_lens = []
         all_input_ids = []
+        prefix_ids = []

         input_lengths = []
         prefix_offsets = []
@@ -760,10 +809,14 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]

+            prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
+
             start_slots.append(batch.start_slots + cumulative_slots)

             block_tables.extend(batch.block_tables)
+            prefix_lens.extend(batch.prefix_lens)
             all_input_ids.extend(batch.all_input_ids)
+            prefix_ids.extend(batch.prefix_ids)

             input_lengths.extend(batch.input_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
@@ -809,6 +862,8 @@ class FlashCausalLMBatch(Batch):
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
@@ -820,6 +875,7 @@ class FlashCausalLMBatch(Batch):
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
+            prefix_ids=prefix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -942,7 +998,7 @@ class FlashCausalLM(Model):
             config.sliding_window = None

         self.num_layers = config.num_hidden_layers
-        self.num_heads = config.num_attention_heads
+        self.num_heads = config.num_attention_heads // self.process_group.size()
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -972,19 +1028,22 @@ class FlashCausalLM(Model):
         self.kv_cache = []

         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_prefill_state,
                 create_decode_state,
+                create_prefill_with_paged_kv_state,
             )

             self.prefill_state = create_prefill_state(device=device)
+            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+                device=device
+            )

-            if not CUDA_GRAPHS:
-                self.decode_state = create_decode_state(
-                    device=device,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                )
+            self.decode_state = create_decode_state(
+                device=device,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )

         super().__init__(
             model_id=model_id,
@@ -1076,12 +1135,23 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
         slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
+        input_lengths = [max_s] * bs
+        prefix_lengths = [0] * bs
+        input_lengths_tensor = (
+            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
         )
+        prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(bs)
+        block_tables = block_tables.reshape((bs, max_bt))
+
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=input_lengths,
+                prefix_lens=prefix_lengths,
+            )

         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
@@ -1089,14 +1159,21 @@ class FlashCausalLM(Model):
             "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
-            "input_lengths": input_lengths,
+            "input_lengths": input_lengths_tensor,
+            "prefix_lengths": prefix_lengths_tensor,
         }
-        input_lengths_ = Seqlen(input_lengths=input_lengths)
+        seqlen = Seqlen(
+            input_lengths=input_lengths_tensor,
+            prefix_lengths=prefix_lengths_tensor,
+            cu_seqlen_q=None,
+            max_q=1,
+            max_k=max_s,
+        )
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph

         if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flash_infer import (
+            from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )

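The warmup path above now builds a Seqlen object instead of passing a bare input-lengths tensor. As a rough mental model only (a minimal sketch with assumed field semantics, not the actual text_generation_server.layers.attention.Seqlen definition), the object simply groups the per-request length tensors and the maxima the attention kernels need:

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class SeqlenSketch:
    # Hypothetical stand-in for Seqlen; field names mirror the call sites above.
    input_lengths: torch.Tensor                  # new (non-cached) tokens per request
    prefix_lengths: torch.Tensor                 # tokens already present in the KV cache
    cu_seqlen_q: Optional[torch.Tensor] = None   # cumulative query lengths (prefill only)
    max_q: int = 1                               # longest query length in the batch
    max_k: int = 0                               # longest total (prefix + input) length

bs, max_s = 4, 128
seqlen = SeqlenSketch(
    input_lengths=torch.full((bs,), max_s, dtype=torch.int32),
    prefix_lengths=torch.zeros(bs, dtype=torch.int32),
    cu_seqlen_q=None,
    max_q=1,
    max_k=max_s,
)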
@@ -1106,7 +1183,7 @@ class FlashCausalLM(Model):
             last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
             state = create_decode_state_cuda_graphs(
                 device=input_ids.device,
-                block_tables=block_tables.view(-1),
+                block_tables=block_tables,
                 block_tables_ptr=block_tables_ptr,
                 last_page_len=last_page_len,
                 num_heads=self.num_heads,
@@ -1122,7 +1199,10 @@ class FlashCausalLM(Model):
             block_tables=block_tables,
             cu_seqlen_prefill=None,
             input_lengths=input_lengths,
+            input_lengths_tensor=input_lengths_tensor,
             state=state,
+            prefix_lens=prefix_lengths,
+            prefix_lens_tensor=prefix_lengths_tensor,
         ):
             self.model.forward(
                 input_ids=input_ids,
@@ -1131,7 +1211,7 @@ class FlashCausalLM(Model):
                 kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths_,
+                seqlen=seqlen,
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
@@ -1140,7 +1220,13 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            input_lengths = Seqlen(input_lengths=input_lengths)
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                prefix_lengths=prefix_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1148,7 +1234,7 @@ class FlashCausalLM(Model):
                 kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths,
+                seqlen=seqlen,
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
@@ -1195,7 +1281,7 @@ class FlashCausalLM(Model):

         num_blocks = (
             # Leave 5% for some wiggle room
-            int((free_memory * 0.95) // total_cache_size)
+            int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
             # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
             + batch_num_blocks
         )
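For reference, the block-count formula above with the new TGI_WIGGLE_ROOM factor works out as follows (a worked example with made-up numbers, not measured values):

# Hypothetical numbers, for illustration only.
free_memory = 20 * 1024**3        # 20 GiB left after loading weights
total_cache_size = 2 * 1024**2    # bytes of KV cache needed per block
batch_num_blocks = 16             # blocks already allocated for the warmup batch
TGI_WIGGLE_ROOM = 0.95            # default, overridable via the env var

num_blocks = int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) + batch_num_blocks
print(num_blocks)  # 9728 + 16 = 9744 blocks with these numbers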
@@ -1287,18 +1373,26 @@ class FlashCausalLM(Model):

         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        cu_seqlen_prefill = torch.tensor(
+            [0, seqlen], device=self.device, dtype=torch.int32
+        )
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            prefix_lengths=prefix_lens_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+            max_q=1,
+            max_k=seqlen,
+        )

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
+            cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
             block_tables=None,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
@@ -1336,6 +1430,9 @@ class FlashCausalLM(Model):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -1356,6 +1453,7 @@ class FlashCausalLM(Model):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

@@ -1374,12 +1472,28 @@ class FlashCausalLM(Model):
             cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths=input_lengths,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths + prefix_lens_tensor,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
             ):
-                input_lengths = Seqlen(input_lengths=input_lengths)
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
@@ -1387,7 +1501,7 @@ class FlashCausalLM(Model):
                     kv_cache=kv_cache,
                     block_tables=block_tables,
                     slots=slots,
-                    input_lengths=input_lengths,
+                    seqlen=seqlen,
                     max_s=max_s,
                     prefill_cache_indices=batch.prefill_cache_indices,
                     lm_head_indices=lm_head_indices,
@@ -1401,20 +1515,32 @@ class FlashCausalLM(Model):
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )

-        state = cuda_graph.get("state")
         with self._forward_context(
-            block_tables=block_tables,
+            block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
-            input_lengths=input_lengths,
-            state=state,
+            input_lengths=batch.input_lengths,
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            prefix_lens=batch.prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
+            state=cuda_graph.get("state"),
         ):
             # Replay the graph
             cuda_graph["graph"].replay()
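The static input_lengths buffer that the captured graph reads is now filled with the total sequence length per request (new tokens plus cached prefix), with unused slots zeroed. In toy numbers (a hypothetical batch smaller than the graph's capture batch size, illustration only):

import torch

graph_bs = 4                                                   # batch size the graph was captured with
input_lengths = torch.tensor([5, 3], dtype=torch.int32)        # new tokens per request
prefix_lens_tensor = torch.tensor([2, 0], dtype=torch.int32)   # cached tokens per request

static_input_lengths = torch.zeros(graph_bs, dtype=torch.int32)  # stand-in for cuda_graph["input_lengths"]
static_input_lengths.zero_()
static_input_lengths[: input_lengths.shape[0]] = input_lengths + prefix_lens_tensor
print(static_input_lengths.tolist())  # [7, 3, 0, 0]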
@@ -1612,6 +1738,7 @@ class FlashCausalLM(Model):
             batch.read_offsets,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.prefix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
@@ -1629,6 +1756,7 @@ class FlashCausalLM(Model):
             read_offset,
             stopping_criteria,
             all_input_ids,
+            prefix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -1703,18 +1831,18 @@ class FlashCausalLM(Model):
                 out_end_index = batch.prefill_cu_outlens[i + 1]

                 # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                    out_start_index : out_end_index - 1
-                ]
+                request_prefill_logprobs = (
+                    [float("nan")] * (len(prefix_ids) + 1)
+                ) + prefill_logprobs[out_start_index : out_end_index - 1]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                     clean_up_tokenization_spaces=False,
                     skip_special_tokens=False,
                 )

                 prefill_tokens = Tokens(
-                    prefill_token_ids,
+                    prefix_ids + prefill_token_ids,
                     request_prefill_logprobs,
                     prefill_texts,
                     is_special=[],
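The logprob padding above can be sanity-checked with a small example (hypothetical values; only the list arithmetic is the point): with 3 cached prefix tokens, the first len(prefix_ids) + 1 = 4 prompt positions get nan, because neither the cached prefix tokens nor the very first prompt token have a logprob from this forward pass.

import math

prefix_ids = [101, 102, 103]                  # tokens served from the prefix cache (hypothetical ids)
prefill_logprobs = [-0.5, -1.2, -0.1, -2.0]   # logprobs returned for the non-cached prompt tokens
out_start_index, out_end_index = 0, 4

request_prefill_logprobs = (
    [float("nan")] * (len(prefix_ids) + 1)
) + prefill_logprobs[out_start_index : out_end_index - 1]

assert len(request_prefill_logprobs) == len(prefix_ids) + 1 + 3
assert all(math.isnan(x) for x in request_prefill_logprobs[:4])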
@@ -1796,33 +1924,68 @@ class FlashCausalLM(Model):
         *,
         block_tables: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
-        input_lengths: torch.Tensor,
+        input_lengths: List[int],
+        input_lengths_tensor: torch.Tensor,
+        prefix_lens: List[int],
+        prefix_lens_tensor: torch.Tensor,
         state: Optional[Any] = None,
     ) -> ContextManager:
         if ATTENTION != "flashinfer":
             return nullcontext()

-        from text_generation_server.layers.attention.flash_infer import (
+        from text_generation_server.layers.attention.flashinfer import (
             use_decode_state,
-            use_prefill_state,
+            use_prefill_with_paged_kv_state,
         )

+        # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
+
         if cu_seqlen_prefill is not None:
-            return use_prefill_state(
-                state=state if state is not None else self.prefill_state,
+            return use_prefill_with_paged_kv_state(
+                state=(
+                    state if state is not None else self.prefill_with_paged_kv_state
+                ),
+                # block_tables=block_tables_to_ragged(
+                #     block_tables=block_tables,
+                #     input_lengths=input_lengths,
+                #     prefix_lens=prefix_lens,
+                # ),
+                block_tables=block_tables,
                 cu_seqlens=cu_seqlen_prefill,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_kv_heads,
-                head_size=self.head_size,
-            )
-        else:
-            assert input_lengths is not None
-            return use_decode_state(
-                state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths,
-                block_tables=block_tables.view(-1),
+                input_lengths=input_lengths_tensor,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
             )
+        else:
+            assert input_lengths_tensor is not None
+            return use_decode_state(
+                state=state if state is not None else self.decode_state,
+                input_lengths=input_lengths_tensor,
+                block_tables=block_tables,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                page_size=BLOCK_SIZE,
+            )
+
+
+def block_tables_to_ragged(
+    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+) -> torch.Tensor:
+    """Convert block table to ragged format compatible with FlashInfer."""
+    assert len(input_lengths) == len(prefix_lens)
+
+    total_len = sum(input_lengths) + sum(prefix_lens)
+    block_tables_ragged = torch.empty(
+        total_len, dtype=torch.int32, device=block_tables.device
+    )
+
+    offset = 0
+    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
+        seq_len = prefix_len + input_length
+        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
+        offset += seq_len
+
+    return block_tables_ragged
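A quick usage sketch for the new helper above, with the function as defined in the hunk in scope and toy values (in the real call sites the lengths are token counts per request): each request contributes prefix_len + input_length entries, so the ragged output is just the per-request rows truncated and concatenated.

import torch

# Two requests, padded block table with 4 slots per row (toy data).
block_tables = torch.tensor(
    [[10, 11, 12, 0],
     [20, 21, 0, 0]],
    dtype=torch.int32,
)
input_lengths = [2, 1]   # new tokens per request
prefix_lens = [1, 1]     # cached tokens per request

ragged = block_tables_to_ragged(
    block_tables=block_tables,
    input_lengths=input_lengths,
    prefix_lens=prefix_lens,
)
print(ragged.tolist())  # [10, 11, 12, 20, 21]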
@@ -5,20 +5,22 @@ from typing import Dict, Optional

 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
-log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
-ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
+ATTENTION = os.getenv("ATTENTION")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION != "flashinfer":
+if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
+assert TGI_WIGGLE_ROOM > 0
+assert TGI_WIGGLE_ROOM < 1

 # This is overridden by the cli
 BLOCK_SIZE: int
@@ -29,7 +31,6 @@ elif ATTENTION == "flashinfer":
 else:
     BLOCK_SIZE = 16

-
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
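The environment handling above is stricter than before: USE_PREFIX_CACHING now has no default (an unset variable would make `.lower()` fail on None, so the launcher is expected to always export it), ATTENTION no longer falls back implicitly, and the wiggle-room factor must sit strictly between 0 and 1. A standalone restatement of the same parsing rules (the defaults here are added only so the snippet runs outside the launcher's environment; the real module has none):

import os

prefix_caching = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
attention = os.getenv("ATTENTION", "paged")
assert attention in {"paged", "flashdecoding", "flashinfer"}

wiggle_room = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
assert 0 < wiggle_room < 1

if prefix_caching and attention not in {"flashinfer", "flashdecoding"}:
    raise RuntimeError("Prefix caching is only supported with flashinfer")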
@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

@@ -349,43 +357,75 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = Seqlen(input_lengths=input_lengths)
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
                 block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-                pixel_values=batch.pixel_values,
-                pixel_attention_mask=batch.pixel_attention_mask,
-                image_sizes=batch.image_sizes,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            if batch.pixel_values is not None:
-                batch.pixel_values = None
-            if batch.pixel_attention_mask is not None:
-                batch.pixel_attention_mask = None
-            if batch.image_sizes is not None:
-                batch.image_sizes = None
-            return logits, speculative_logits
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                    pixel_values=batch.pixel_values,
+                    pixel_attention_mask=batch.pixel_attention_mask,
+                    image_sizes=batch.image_sizes,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                if batch.pixel_values is not None:
+                    batch.pixel_values = None
+                if batch.pixel_attention_mask is not None:
+                    batch.pixel_attention_mask = None
+                if batch.image_sizes is not None:
+                    batch.image_sizes = None
+                return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
         cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
         cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                prefix_lens=batch.prefix_lens,
+            )
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
         cuda_graph["slots"].fill_(-1)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+            input_lengths + prefix_lens_tensor
+        )

         # Replay the graph
         cuda_graph["graph"].replay()
@@ -3,6 +3,7 @@
 # License: Apache License Version 2.0, January 2004

 import warnings
+import re
 from dataclasses import dataclass
 from functools import lru_cache
 from typing import TYPE_CHECKING, Set, Tuple, Optional, List
@@ -27,6 +28,7 @@ BASE_MODEL_ADAPTER_ID = "__base_model__"
 class AdapterInfo:
     id: str
     path: Optional[str]
+    revision: Optional[str] = None


 @dataclass
@@ -51,11 +53,16 @@ def parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:

     adapter_list = []
     for adapter in lora_adapters.split(","):
-        parts = adapter.strip().split("=")
-        if len(parts) == 1:
-            adapter_list.append(AdapterInfo(id=parts[0], path=None))
-        elif len(parts) == 2:
-            adapter_list.append(AdapterInfo(id=parts[0], path=parts[1]))
+        adapter = adapter.strip()
+        if adapter.count("=") > 1 or adapter.count("@") > 1:
+            raise ValueError(f"Invalid LoRA adapter format: {adapter}")
+        match = re.match(r"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$", adapter)
+
+        if match:
+            adapter_id, path, revision = match.groups()
+            adapter_list.append(
+                AdapterInfo(id=adapter_id, path=path, revision=revision)
+            )
         else:
             raise ValueError(f"Invalid LoRA adapter format: {adapter}")
     return adapter_list
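The regex above accepts three spellings — a bare adapter id, id=path, and an optional trailing @revision — and rejects anything with more than one "=" or "@". A few illustrative inputs (hypothetical adapter names, assuming parse_lora_adapters is importable as defined above):

# parse_lora_adapters("foo/awesome-lora")
#   -> [AdapterInfo(id="foo/awesome-lora", path=None, revision=None)]
# parse_lora_adapters("foo/awesome-lora@v1,bar/other-lora")
#   -> [AdapterInfo(id="foo/awesome-lora", path=None, revision="v1"),
#       AdapterInfo(id="bar/other-lora", path=None, revision=None)]
# parse_lora_adapters("a=b=c")  -> ValueError: Invalid LoRA adapter format: a=b=c

infos = parse_lora_adapters("foo/awesome-lora=/data/adapters/awesome@v1")
assert infos[0].id == "foo/awesome-lora"
assert infos[0].path == "/data/adapters/awesome"
assert infos[0].revision == "v1"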
@@ -73,6 +80,7 @@ def load_and_merge_adapters(
         adapter_info = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
+            adapter_info.revision,
             adapter_info.id,
             adapter_info.path,
             weight_names,
@@ -80,7 +88,13 @@ def load_and_merge_adapters(
         )

     adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
-    return _load_and_merge(model_id, adapter_params, weight_names, trust_remote_code)
+    return _load_and_merge(
+        model_id,
+        adapter_params.revision,
+        adapter_params,
+        weight_names,
+        trust_remote_code,
+    )


 @dataclass
@@ -95,6 +109,7 @@ class AdapterParametersContainer:
 @lru_cache(maxsize=32)
 def _load_and_merge(
     model_id: str,
+    revision: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
@@ -171,12 +186,12 @@ def check_architectures(
 @lru_cache(maxsize=128)
 def load_module_map(
     model_id: str,
+    revision: str,
     adapter_id: str,
     adapter_path: Optional[str],
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
-    revision = "main"

     adapter_config = LoraConfig.load(adapter_path or adapter_id, None)

@@ -191,6 +206,12 @@ def load_module_map(
         )
     )

+    # throw an error if no adapter weights are found
+    if not adapter_filenames:
+        raise FileNotFoundError(
+            f"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'."
+        )
+
     try:
         adapter_tokenizer = AutoTokenizer.from_pretrained(
             adapter_config.config_path,
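A side effect of threading revision through _load_and_merge and load_module_map above: both are wrapped in functools.lru_cache, and every positional argument is part of the cache key, so the same adapter id at two different revisions now occupies two cache entries instead of silently reusing the first one loaded. A minimal, self-contained illustration of that lru_cache behaviour (toy function, not the real loader):

from functools import lru_cache

calls = []

@lru_cache(maxsize=32)
def load_adapter(adapter_id: str, revision: str):
    calls.append((adapter_id, revision))
    return f"weights for {adapter_id}@{revision}"

load_adapter("foo/awesome-lora", "main")
load_adapter("foo/awesome-lora", "main")   # cache hit, no second call
load_adapter("foo/awesome-lora", "v2")     # different revision -> new cache entry

assert calls == [("foo/awesome-lora", "main"), ("foo/awesome-lora", "v2")]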
@@ -221,6 +242,12 @@ def get_attn_weights(i, layer):
         value = (f"model.layers.{i}.self_attn.{k}_proj", qkv)
         weights[key] = value

+    # also add the qkv_proj weight for the adapter
+    weights[(i, "qkv_proj")] = (
+        f"model.layers.{i}.self_attn.qkv_proj",
+        qkv,
+    )
+
     weights[(i, "o_proj")] = (
         f"model.layers.{i}.self_attn.o_proj",
         layer.self_attn.o_proj,
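With the extra entry added above, the mapping returned for a single layer now exposes the fused attention projection under both the per-head names and qkv_proj. Roughly, for layer 0 (a sketch; `qkv` stands for the fused attention module bound earlier in the function, which is not shown in this hunk):

# get_attn_weights(0, layer) ->
# {
#     (0, "q_proj"):   ("model.layers.0.self_attn.q_proj",   qkv),
#     (0, "k_proj"):   ("model.layers.0.self_attn.k_proj",   qkv),
#     (0, "v_proj"):   ("model.layers.0.self_attn.v_proj",   qkv),
#     (0, "qkv_proj"): ("model.layers.0.self_attn.qkv_proj", qkv),
#     (0, "o_proj"):   ("model.layers.0.self_attn.o_proj",   layer.self_attn.o_proj),
# }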