Impl simple mamba model (#1480)
This draft PR is a work in progress implementation of the mamba model. This PR currently loads weights, and produces correct logits after a single pass. This PR still needs to correctly integrate this model so it produces tokens as expected, and apply optimization to avoid all copies during runtime/unnecessary operations. #### Helpful resources [Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752) https://github.com/johnma2006/mamba-minimal https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs https://github.com/huggingface/transformers/pull/28094 Notes: this dev work is currently targeting `state-spaces/mamba-130m`, so if you want to test please use that model. Additionally when starting the router the prefill needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768` ## Update / Current State Integration tests have been added and basic functionality such as model loading is supported. ```bash cd integration-tests pytest -vv models/test_fused_kernel_mamba.py ``` - [x] add tests - [x] load model - [x] make simple request - [ ] resolve warmup issue - [ ] resolve output issues fetching models tested during dev ```bash text-generation-server download-weights state-spaces/mamba-130m text-generation-server download-weights state-spaces/mamba-1.4b text-generation-server download-weights state-spaces/mamba-2.8b ``` The server can be run ```bash cd server MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b ``` router ```bash cargo run ``` make a request ```bash curl -s localhost:3000/generate \ -X POST \ -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \ -H 'Content-Type: application/json' | jq ``` response ```json { "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data." } ``` --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
1734540211
commit
bd405e035b
10
Dockerfile
10
Dockerfile
|
@ -154,6 +154,12 @@ COPY server/Makefile-vllm Makefile
|
||||||
# Build specific version of vllm
|
# Build specific version of vllm
|
||||||
RUN make build-vllm-cuda
|
RUN make build-vllm-cuda
|
||||||
|
|
||||||
|
# Build mamba kernels
|
||||||
|
FROM kernel-builder as mamba-builder
|
||||||
|
WORKDIR /usr/src
|
||||||
|
COPY server/Makefile-selective-scan Makefile
|
||||||
|
RUN make build-all
|
||||||
|
|
||||||
# Build megablocks
|
# Build megablocks
|
||||||
FROM kernel-builder as megablocks-builder
|
FROM kernel-builder as megablocks-builder
|
||||||
|
|
||||||
|
@ -205,6 +211,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31
|
||||||
# Copy builds artifacts from vllm builder
|
# Copy builds artifacts from vllm builder
|
||||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
||||||
|
# Copy build artifacts from mamba builder
|
||||||
|
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
|
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
|
||||||
|
|
||||||
# Install flash-attention dependencies
|
# Install flash-attention dependencies
|
||||||
RUN pip install einops --no-cache-dir
|
RUN pip install einops --no-cache-dir
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.3552246,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38378906,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.140625,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.5551758,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59033203,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.70654297,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0410156,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0026435852,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2841797,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 2502,
|
||||||
|
"logprob": null,
|
||||||
|
"text": " red"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.5234375,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8862,
|
||||||
|
"logprob": -3.4433594,
|
||||||
|
"text": " yellow"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.43017578,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 209,
|
||||||
|
"logprob": -8.21875,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 395,
|
||||||
|
"logprob": -0.46411133,
|
||||||
|
"special": false,
|
||||||
|
"text": "and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13735,
|
||||||
|
"logprob": -2.1132812,
|
||||||
|
"special": false,
|
||||||
|
"text": " orange"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 313,
|
||||||
|
"logprob": -1.2128906,
|
||||||
|
"special": false,
|
||||||
|
"text": " ("
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 249,
|
||||||
|
"logprob": -2.3671875,
|
||||||
|
"special": false,
|
||||||
|
"text": "in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 253,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " the"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1340,
|
||||||
|
"logprob": -1.640625,
|
||||||
|
"special": false,
|
||||||
|
"text": " order"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 597,
|
||||||
|
"logprob": -0.5488281,
|
||||||
|
"special": false,
|
||||||
|
"text": " they"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3176,
|
||||||
|
"logprob": -0.48608398,
|
||||||
|
"special": false,
|
||||||
|
"text": " appear"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 275,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "blue, red, yellow, \nand orange (in the order they appear in"
|
||||||
|
}
|
|
@ -0,0 +1,398 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.8125,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.828125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -3.0,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1484375,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.3552246,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38378906,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1279297,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.5595703,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.60253906,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7050781,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0488281,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3808594,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0026416779,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.78027344,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.8203125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -2.9902344,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1523438,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.35351562,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38256836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.54541016,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0585938,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0027446747,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.78027344,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.8203125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -2.9902344,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1523438,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.35351562,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38256836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.54541016,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0585938,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0027446747,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1276,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.78027344,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18147,
|
||||||
|
"logprob": -12.8203125,
|
||||||
|
"text": " Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20727,
|
||||||
|
"logprob": -2.9902344,
|
||||||
|
"text": " Learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -1.1523438,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.35351562,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 187,
|
||||||
|
"logprob": -0.38256836,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30763,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4715,
|
||||||
|
"logprob": -0.54541016,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 310,
|
||||||
|
"logprob": -0.59765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 247,
|
||||||
|
"logprob": -0.7001953,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 747,
|
||||||
|
"logprob": -2.0585938,
|
||||||
|
"special": false,
|
||||||
|
"text": " new"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1511,
|
||||||
|
"logprob": -2.3789062,
|
||||||
|
"special": false,
|
||||||
|
"text": " type"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 273,
|
||||||
|
"logprob": -0.0027446747,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5145,
|
||||||
|
"logprob": -1.2851562,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nDeep learning is a new type of machine"
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,59 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def fused_kernel_mamba_handle(launcher):
|
||||||
|
with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
||||||
|
await fused_kernel_mamba_handle.health(300)
|
||||||
|
return fused_kernel_mamba_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
||||||
|
response = await fused_kernel_mamba.generate(
|
||||||
|
"What is Deep Learning?", max_new_tokens=10
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response.generated_text == "\n\nDeep learning is a new type of machine"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||||
|
response = await fused_kernel_mamba.generate(
|
||||||
|
"blue, red, yellow, ",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
|
||||||
|
responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
|
@ -161,3 +161,4 @@ flash-attention-v2/
|
||||||
vllm/
|
vllm/
|
||||||
llm-awq/
|
llm-awq/
|
||||||
eetq/
|
eetq/
|
||||||
|
mamba/
|
||||||
|
|
|
@ -3,6 +3,7 @@ include Makefile-flash-att-v2
|
||||||
include Makefile-vllm
|
include Makefile-vllm
|
||||||
include Makefile-awq
|
include Makefile-awq
|
||||||
include Makefile-eetq
|
include Makefile-eetq
|
||||||
|
include Makefile-selective-scan
|
||||||
|
|
||||||
unit-tests:
|
unit-tests:
|
||||||
pytest -s -vv -m "not private" tests
|
pytest -s -vv -m "not private" tests
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137
|
||||||
|
|
||||||
|
causal-conv1d:
|
||||||
|
rm -rf causal-conv1d
|
||||||
|
git clone https://github.com/Dao-AILab/causal-conv1d.git
|
||||||
|
|
||||||
|
build-causal-conv1d: causal-conv1d
|
||||||
|
cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag
|
||||||
|
cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build
|
||||||
|
|
||||||
|
install-causal-conv1d: build-causal-conv1d
|
||||||
|
pip uninstall causal-conv1d -y || true
|
||||||
|
cd causal-conv1d/ && pip install .
|
||||||
|
|
||||||
|
# selective-scan dependends on causal-conv1d
|
||||||
|
selective-scan:
|
||||||
|
rm -rf mamba
|
||||||
|
git clone https://github.com/state-spaces/mamba.git mamba
|
||||||
|
|
||||||
|
build-selective-scan: selective-scan
|
||||||
|
cd mamba/ && git fetch && git checkout $(selective_scan_commit)
|
||||||
|
cd mamba && python setup.py build
|
||||||
|
|
||||||
|
install-selective-scan: install-causal-conv1d build-selective-scan
|
||||||
|
pip uninstall selective-scan-cuda -y || true
|
||||||
|
cd mamba && pip install .
|
||||||
|
|
||||||
|
build-all: build-causal-conv1d build-selective-scan
|
|
@ -76,6 +76,15 @@ if FLASH_ATTENTION:
|
||||||
__all__.append(FlashMixtral)
|
__all__.append(FlashMixtral)
|
||||||
__all__.append(FlashPhi)
|
__all__.append(FlashPhi)
|
||||||
|
|
||||||
|
MAMBA_AVAILABLE = True
|
||||||
|
try:
|
||||||
|
from text_generation_server.models.mamba import Mamba
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"Could not import Mamba: {e}")
|
||||||
|
MAMBA_AVAILABLE = False
|
||||||
|
|
||||||
|
if MAMBA_AVAILABLE:
|
||||||
|
__all__.append(Mamba)
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -164,7 +173,25 @@ def get_model(
|
||||||
if speculate > 0:
|
if speculate > 0:
|
||||||
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
logger.info(f"Using speculation {method} with {speculate} input ids.")
|
||||||
|
|
||||||
model_type = config_dict["model_type"]
|
model_type = config_dict.get("model_type", None)
|
||||||
|
if model_type is None:
|
||||||
|
# TODO: fix how we determine model type for Mamba
|
||||||
|
if "ssm_cfg" in config_dict:
|
||||||
|
# *only happens in Mamba case
|
||||||
|
model_type = "ssm"
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not determine model type for {model_id} revision {revision}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type == "ssm":
|
||||||
|
return Mamba(
|
||||||
|
model_id,
|
||||||
|
revision,
|
||||||
|
quantize=quantize,
|
||||||
|
dtype=dtype,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
if model_type == "gpt_bigcode":
|
if model_type == "gpt_bigcode":
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
|
|
|
@ -0,0 +1,194 @@
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
|
||||||
|
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
||||||
|
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
||||||
|
from mamba_ssm.utils.generation import InferenceParams
|
||||||
|
from torch import nn
|
||||||
|
from typing import Optional, Tuple, Any
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from text_generation_server.utils.layers import (
|
||||||
|
TensorParallelEmbedding,
|
||||||
|
FastRMSNorm,
|
||||||
|
FastLinear,
|
||||||
|
)
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
||||||
|
import math
|
||||||
|
|
||||||
|
class MambaConfig(PretrainedConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50280,
|
||||||
|
d_model=768,
|
||||||
|
d_state=16,
|
||||||
|
n_layer=32,
|
||||||
|
layer_norm_epsilon=1e-5,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
expand=2,
|
||||||
|
dt_rank="auto",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
self.d_model = d_model
|
||||||
|
self.d_inner = d_model * 2
|
||||||
|
self.d_conv = 4
|
||||||
|
self.d_state = d_state
|
||||||
|
self.expand = expand
|
||||||
|
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
class MambaBlock(nn.Module):
|
||||||
|
def __init__(self, prefix, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_idx = int(prefix.split(".")[2])
|
||||||
|
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
|
||||||
|
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
|
||||||
|
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
|
||||||
|
self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
|
||||||
|
self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
|
||||||
|
self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
|
||||||
|
self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
|
||||||
|
self.D = weights.get_tensor(f"{prefix}.D")
|
||||||
|
self.activation = "silu"
|
||||||
|
self.dt_rank = config.dt_rank
|
||||||
|
self.d_state = config.d_state
|
||||||
|
self.d_conv = config.d_conv
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
# inference_params
|
||||||
|
def forward(self, hidden_states: torch.Tensor, inference_params=None):
|
||||||
|
_, seqlen, _ = hidden_states.shape
|
||||||
|
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
|
||||||
|
|
||||||
|
if inference_params.seqlen_offset > 0:
|
||||||
|
out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
|
||||||
|
return out, conv_state, ssm_state
|
||||||
|
|
||||||
|
projected_states = self.in_proj(hidden_states).transpose(1,2)
|
||||||
|
x, z = projected_states.chunk(2, dim=1)
|
||||||
|
conv_state = F.pad(x, (self.d_conv - seqlen, 0))
|
||||||
|
x = causal_conv1d_fn(
|
||||||
|
x=x,
|
||||||
|
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
|
||||||
|
bias=self.conv1d.bias,
|
||||||
|
activation=self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We're careful here about the layout, to avoid extra transposes.
|
||||||
|
# We want dt to have d as the slowest moving dimension
|
||||||
|
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||||
|
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
||||||
|
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
|
dt = self.dt_proj.weight @ dt.t()
|
||||||
|
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
||||||
|
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||||
|
y, last_state = selective_scan_fn(
|
||||||
|
x,
|
||||||
|
dt,
|
||||||
|
self.negA,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
self.D.float(),
|
||||||
|
z=z,
|
||||||
|
delta_bias=self.dt_proj.bias.float(),
|
||||||
|
delta_softplus=True,
|
||||||
|
return_last_state=True,
|
||||||
|
)
|
||||||
|
y = rearrange(y, "b d l -> b l d")
|
||||||
|
attn_outputs = self.out_proj(y)
|
||||||
|
return attn_outputs, conv_state, last_state
|
||||||
|
|
||||||
|
def step(self, hidden_states, conv_state, ssm_state):
|
||||||
|
_xz = self.in_proj(hidden_states)
|
||||||
|
_x, _z = _xz.chunk(2, dim=-1) # (B D)
|
||||||
|
conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
|
||||||
|
conv_out = causal_conv1d_fn(
|
||||||
|
x=conv_state_new,
|
||||||
|
weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
|
||||||
|
bias=self.conv1d.bias,
|
||||||
|
activation=self.activation
|
||||||
|
)
|
||||||
|
conv_state = conv_state_new[:, :, 1:]
|
||||||
|
bsz, seqlen, dim = hidden_states.shape
|
||||||
|
output_tensor = torch.zeros(
|
||||||
|
(bsz, seqlen, dim),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
for i in range(0, bsz):
|
||||||
|
x = conv_out[i:i+1,:,-1]
|
||||||
|
z = _z[i:i+1, -1, :]
|
||||||
|
x_db = self.x_proj(x)
|
||||||
|
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
||||||
|
dt = F.linear(dt, self.dt_proj.weight)
|
||||||
|
y = selective_state_update(
|
||||||
|
ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
|
||||||
|
)
|
||||||
|
out = self.out_proj(y)
|
||||||
|
output_tensor[i] = out
|
||||||
|
|
||||||
|
return output_tensor, conv_state, ssm_state
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, layer_id, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
|
||||||
|
self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
inference_params: Optional[Any] = None,
|
||||||
|
):
|
||||||
|
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||||
|
shape = residual.shape
|
||||||
|
hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
|
||||||
|
hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
|
||||||
|
return hidden_states, residual, conv_state, last_ssm_state
|
||||||
|
|
||||||
|
class MambaModel(nn.Module):
|
||||||
|
def __init__(self, config, weights):
|
||||||
|
super().__init__()
|
||||||
|
prefix = "backbone"
|
||||||
|
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
|
||||||
|
)
|
||||||
|
self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
|
||||||
|
self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
for block in self.blocks:
|
||||||
|
hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
|
||||||
|
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + residual if residual is not None else hidden_states
|
||||||
|
hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
|
||||||
|
hidden_states = hidden_states.view(residual.shape)
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
# update the offset for the next inference using these params
|
||||||
|
inference_params.seqlen_offset += input_ids.size(1)
|
||||||
|
return logits, input_ids, inference_params
|
|
@ -0,0 +1,656 @@
|
||||||
|
import torch
|
||||||
|
import torch.distributed
|
||||||
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
from typing import Optional
|
||||||
|
from text_generation_server.models.custom_modeling.mamba_modeling import (
|
||||||
|
MambaConfig,
|
||||||
|
)
|
||||||
|
from text_generation_server.pb import generate_pb2
|
||||||
|
from text_generation_server.utils import (
|
||||||
|
initialize_torch_distributed,
|
||||||
|
weight_files,
|
||||||
|
Weights,
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel
|
||||||
|
from text_generation_server.models import Model
|
||||||
|
from typing import Any, List, Optional, Tuple, Type, Dict
|
||||||
|
from text_generation_server.models.types import (
|
||||||
|
Batch,
|
||||||
|
Tokens,
|
||||||
|
Generation,
|
||||||
|
GeneratedText,
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||||
|
from mamba_ssm.utils.generation import InferenceParams
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MambaBatch(Batch):
|
||||||
|
batch_id: int
|
||||||
|
requests: List[generate_pb2.Request]
|
||||||
|
requests_idx_mapping: Dict[int, int]
|
||||||
|
|
||||||
|
# Decoder values
|
||||||
|
input_ids: torch.Tensor
|
||||||
|
|
||||||
|
# All tokens
|
||||||
|
all_input_ids: List[torch.Tensor]
|
||||||
|
|
||||||
|
# Lengths of all generations present in the batch
|
||||||
|
input_lengths: List[int]
|
||||||
|
prefix_offsets: List[int]
|
||||||
|
read_offsets: List[int]
|
||||||
|
|
||||||
|
# Generation helpers
|
||||||
|
next_token_choosers: List[NextTokenChooser]
|
||||||
|
stopping_criterias: List[StoppingCriteria]
|
||||||
|
top_n_tokens: List[int]
|
||||||
|
top_n_tokens_tensor: torch.Tensor
|
||||||
|
|
||||||
|
# Metadata used for padding
|
||||||
|
max_input_length: int
|
||||||
|
padding_right_offset: int
|
||||||
|
|
||||||
|
# Maximum number of tokens this batch will grow to
|
||||||
|
max_tokens: int
|
||||||
|
|
||||||
|
# Past metadata
|
||||||
|
keys_head_dim_last: bool = True
|
||||||
|
|
||||||
|
# Inference params
|
||||||
|
inference_params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def to_pb(self) -> generate_pb2.CachedBatch:
|
||||||
|
return generate_pb2.CachedBatch(
|
||||||
|
id=self.batch_id,
|
||||||
|
request_ids=[r.id for r in self.requests],
|
||||||
|
size=len(self),
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pb(
|
||||||
|
cls,
|
||||||
|
pb: generate_pb2.Batch,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> "MambaBatch":
|
||||||
|
inputs = []
|
||||||
|
next_token_choosers = []
|
||||||
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
|
# Parse batch
|
||||||
|
max_truncation = 0
|
||||||
|
padding_right_offset = 0
|
||||||
|
max_decode_tokens = 0
|
||||||
|
for i, r in enumerate(pb.requests):
|
||||||
|
requests_idx_mapping[r.id] = i
|
||||||
|
inputs.append(r.inputs)
|
||||||
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
||||||
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
|
r.stopping_parameters, tokenizer
|
||||||
|
)
|
||||||
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
top_n_tokens.append(r.top_n_tokens)
|
||||||
|
max_truncation = max(max_truncation, r.truncate)
|
||||||
|
max_decode_tokens += stopping_criteria.max_new_tokens
|
||||||
|
padding_right_offset = max(
|
||||||
|
padding_right_offset, stopping_criteria.max_new_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenized_inputs = tokenizer(
|
||||||
|
inputs,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
return_token_type_ids=False,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_truncation,
|
||||||
|
).to(device)
|
||||||
|
for _ in pb.requests:
|
||||||
|
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||||
|
prefix_offsets.append(input_len - 5)
|
||||||
|
read_offsets.append(input_len)
|
||||||
|
|
||||||
|
input_lengths = tokenized_inputs["attention_mask"].sum(1)
|
||||||
|
max_input_length = input_lengths.max()
|
||||||
|
input_ids = tokenized_inputs["input_ids"]
|
||||||
|
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
|
||||||
|
top_n_tokens_tensor = torch.tensor(
|
||||||
|
top_n_tokens, device=device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
|
||||||
|
return cls(
|
||||||
|
batch_id=pb.id,
|
||||||
|
requests=pb.requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids=input_ids,
|
||||||
|
# past_input_ids=None,
|
||||||
|
all_input_ids=list(all_input_ids),
|
||||||
|
input_lengths=input_lengths.tolist(),
|
||||||
|
prefix_offsets=prefix_offsets,
|
||||||
|
read_offsets=read_offsets,
|
||||||
|
next_token_choosers=next_token_choosers,
|
||||||
|
stopping_criterias=stopping_criterias,
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
|
max_input_length=max_input_length.item(),
|
||||||
|
padding_right_offset=padding_right_offset,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]:
|
||||||
|
if len(request_ids) == 0:
|
||||||
|
raise ValueError("Batch must have at least one request")
|
||||||
|
if len(request_ids) == len(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
keep_indices = []
|
||||||
|
|
||||||
|
# New values after filtering
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
requests = []
|
||||||
|
input_lengths = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
all_input_ids = []
|
||||||
|
max_input_length = 0
|
||||||
|
|
||||||
|
next_token_choosers = []
|
||||||
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
|
||||||
|
total_remaining_decode_tokens = 0
|
||||||
|
new_padding_right_offset = 0
|
||||||
|
|
||||||
|
indices = []
|
||||||
|
for i, request_id in enumerate(request_ids):
|
||||||
|
idx = self.requests_idx_mapping[request_id]
|
||||||
|
requests_idx_mapping[request_id] = i
|
||||||
|
keep_indices.append(idx)
|
||||||
|
|
||||||
|
requests.append(self.requests[idx])
|
||||||
|
prefix_offsets.append(self.prefix_offsets[idx])
|
||||||
|
read_offsets.append(self.read_offsets[idx])
|
||||||
|
all_input_ids.append(self.all_input_ids[idx])
|
||||||
|
|
||||||
|
request_input_length = self.input_lengths[idx]
|
||||||
|
input_lengths.append(request_input_length)
|
||||||
|
max_input_length = max(max_input_length, request_input_length)
|
||||||
|
indices.append(idx)
|
||||||
|
|
||||||
|
next_token_choosers.append(self.next_token_choosers[idx])
|
||||||
|
stopping_criteria = self.stopping_criterias[idx]
|
||||||
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
top_n_tokens.append(self.top_n_tokens[idx])
|
||||||
|
remaining_decode_tokens = (
|
||||||
|
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
|
)
|
||||||
|
total_remaining_decode_tokens += remaining_decode_tokens
|
||||||
|
new_padding_right_offset = max(
|
||||||
|
new_padding_right_offset, remaining_decode_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply indices to input_ids, attention mask, past key values and other items that need to be cached
|
||||||
|
input_ids = self.input_ids[keep_indices]
|
||||||
|
|
||||||
|
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
|
||||||
|
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
|
||||||
|
|
||||||
|
self.requests = requests
|
||||||
|
self.requests_idx_mapping = requests_idx_mapping
|
||||||
|
self.input_ids = input_ids
|
||||||
|
self.all_input_ids = all_input_ids
|
||||||
|
self.input_lengths = input_lengths
|
||||||
|
self.prefix_offsets = prefix_offsets
|
||||||
|
self.read_offsets = read_offsets
|
||||||
|
self.next_token_choosers = next_token_choosers
|
||||||
|
self.stopping_criterias = stopping_criterias
|
||||||
|
self.top_n_tokens = top_n_tokens
|
||||||
|
self.top_n_tokens_tensor = top_n_tokens_tensor
|
||||||
|
self.max_input_length = max_input_length
|
||||||
|
self.padding_right_offset = new_padding_right_offset
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
|
||||||
|
key_value_memory_dict = {}
|
||||||
|
for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
|
||||||
|
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
|
||||||
|
self.inference_params.key_value_memory_dict = key_value_memory_dict
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
|
||||||
|
# Used for padding
|
||||||
|
total_batch_size = 0
|
||||||
|
max_input_length = 0
|
||||||
|
padding_right_offset = 0
|
||||||
|
for batch in batches:
|
||||||
|
total_batch_size += len(batch)
|
||||||
|
max_input_length = max(max_input_length, batch.max_input_length)
|
||||||
|
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||||
|
|
||||||
|
# Batch attributes
|
||||||
|
requests = []
|
||||||
|
requests_idx_mapping = {}
|
||||||
|
input_lengths = []
|
||||||
|
prefix_offsets = []
|
||||||
|
read_offsets = []
|
||||||
|
all_input_ids = []
|
||||||
|
next_token_choosers = []
|
||||||
|
stopping_criterias = []
|
||||||
|
top_n_tokens = []
|
||||||
|
max_tokens = 0
|
||||||
|
max_seqlen = 0
|
||||||
|
batch_size = 0
|
||||||
|
seqlen_offset = 0
|
||||||
|
|
||||||
|
# Batch tensors
|
||||||
|
input_ids = None
|
||||||
|
top_n_tokens_tensor = None
|
||||||
|
|
||||||
|
# Used for slicing correctly inside the tensors
|
||||||
|
# Equivalent to a cumsum on batch sizes
|
||||||
|
start_index = 0
|
||||||
|
for i, batch in enumerate(batches):
|
||||||
|
requests.extend(batch.requests)
|
||||||
|
input_lengths.extend(batch.input_lengths)
|
||||||
|
prefix_offsets.extend(batch.prefix_offsets)
|
||||||
|
read_offsets.extend(batch.read_offsets)
|
||||||
|
all_input_ids.extend(batch.all_input_ids)
|
||||||
|
next_token_choosers.extend(batch.next_token_choosers)
|
||||||
|
stopping_criterias.extend(batch.stopping_criterias)
|
||||||
|
top_n_tokens.extend(batch.top_n_tokens)
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
requests_idx_mapping = batch.requests_idx_mapping
|
||||||
|
else:
|
||||||
|
# We need to offset the mapping for each batch by the cumulative batch size
|
||||||
|
for k, v in batch.requests_idx_mapping.items():
|
||||||
|
requests_idx_mapping[k] = v + start_index
|
||||||
|
|
||||||
|
# Slicing end index for this batch
|
||||||
|
end_index = start_index + len(batch)
|
||||||
|
|
||||||
|
# Create empty tensor
|
||||||
|
# input_ids is always of shape [batch_size, 1]
|
||||||
|
# We do not need to pad it
|
||||||
|
if input_ids is None:
|
||||||
|
input_ids = batch.input_ids.new_empty((total_batch_size, 1))
|
||||||
|
# Copy to correct indices
|
||||||
|
input_ids[start_index:end_index] = batch.input_ids
|
||||||
|
|
||||||
|
if top_n_tokens_tensor is None:
|
||||||
|
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
|
||||||
|
total_batch_size,
|
||||||
|
)
|
||||||
|
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
|
||||||
|
|
||||||
|
# Add eventual padding tokens that were added while concatenating
|
||||||
|
max_tokens += batch.max_tokens + (
|
||||||
|
max_input_length - batch.max_input_length
|
||||||
|
) * len(batch)
|
||||||
|
|
||||||
|
max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen)
|
||||||
|
seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset)
|
||||||
|
batch_size += batch.inference_params.max_batch_size
|
||||||
|
|
||||||
|
start_index = end_index
|
||||||
|
|
||||||
|
|
||||||
|
(_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
|
||||||
|
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
|
||||||
|
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
|
||||||
|
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
|
||||||
|
device = batches[0].inference_params.key_value_memory_dict[0][0].device
|
||||||
|
|
||||||
|
key_value_memory_dict = {}
|
||||||
|
for i in range(n_blocks):
|
||||||
|
conv_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
d_model,
|
||||||
|
d_conv,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
ssm_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
d_model,
|
||||||
|
d_state,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
key_value_memory_dict[i] = (conv_state, ssm_state)
|
||||||
|
lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
inference_params = InferenceParams(
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
max_batch_size=batch_size,
|
||||||
|
seqlen_offset=seqlen_offset,
|
||||||
|
key_value_memory_dict=key_value_memory_dict,
|
||||||
|
lengths_per_sample=lengths_per_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_batch = 0
|
||||||
|
for batch in batches:
|
||||||
|
for i in range(n_blocks):
|
||||||
|
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
|
||||||
|
batch_size = batch.inference_params.max_batch_size
|
||||||
|
inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
|
||||||
|
inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
|
||||||
|
inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
|
||||||
|
current_batch += batch_size
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
batch_id=batches[0].batch_id,
|
||||||
|
requests=requests,
|
||||||
|
requests_idx_mapping=requests_idx_mapping,
|
||||||
|
input_ids=input_ids,
|
||||||
|
all_input_ids=all_input_ids,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
prefix_offsets=prefix_offsets,
|
||||||
|
read_offsets=read_offsets,
|
||||||
|
next_token_choosers=next_token_choosers,
|
||||||
|
stopping_criterias=stopping_criterias,
|
||||||
|
top_n_tokens=top_n_tokens,
|
||||||
|
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||||
|
max_input_length=max_input_length,
|
||||||
|
padding_right_offset=padding_right_offset,
|
||||||
|
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
inference_params=inference_params
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.requests)
|
||||||
|
|
||||||
|
class Mamba(Model):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
quantize: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
self.process_group, _rank, _world_size = initialize_torch_distributed()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
else:
|
||||||
|
if quantize:
|
||||||
|
raise ValueError("quantization is not available on CPU")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.float32 if dtype is None else dtype
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"EleutherAI/gpt-neox-20b",
|
||||||
|
revision=revision,
|
||||||
|
padding_side="left",
|
||||||
|
truncation_side="left",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
config = MambaConfig.from_pretrained(
|
||||||
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer.bos_token_id = config.bos_token_id
|
||||||
|
tokenizer.eos_token_id = config.eos_token_id
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
config.quantize = quantize
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
model = MambaModel(config, weights)
|
||||||
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
super(Mamba, self).__init__(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
requires_padding=True,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_type(self) -> Type[MambaBatch]:
|
||||||
|
return MambaBatch
|
||||||
|
|
||||||
|
def warmup(self, batch) -> Optional[int]:
|
||||||
|
# TODO: implement warmup for Mamba if needed
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
past: Optional[List[torch.Tensor]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return self.model(
|
||||||
|
input_ids,
|
||||||
|
past=past,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
|
||||||
|
start = time.time_ns()
|
||||||
|
input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
|
||||||
|
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
max_seqlen = input_ids.shape[1]
|
||||||
|
dtype = input_ids.dtype
|
||||||
|
|
||||||
|
# Inference params
|
||||||
|
seqlen_og = 0
|
||||||
|
inf_cache = {}
|
||||||
|
lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
|
||||||
|
|
||||||
|
if batch.inference_params is None:
|
||||||
|
inference_params = InferenceParams(
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
max_batch_size=batch_size,
|
||||||
|
seqlen_offset=seqlen_og,
|
||||||
|
key_value_memory_dict=inf_cache,
|
||||||
|
lengths_per_sample=lengths_per_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allocate inference cache
|
||||||
|
for res_block in self.model.blocks:
|
||||||
|
block = res_block.mamba_block
|
||||||
|
conv_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.model.config.d_model * self.model.config.expand,
|
||||||
|
self.model.config.d_conv,
|
||||||
|
device=block.conv1d.weight.device,
|
||||||
|
dtype=block.conv1d.weight.dtype,
|
||||||
|
)
|
||||||
|
ssm_state = torch.zeros(
|
||||||
|
batch_size,
|
||||||
|
self.model.config.d_model * self.model.config.expand,
|
||||||
|
self.model.config.d_state,
|
||||||
|
device=block.dt_proj.weight.device,
|
||||||
|
dtype=block.dt_proj.weight.dtype,
|
||||||
|
)
|
||||||
|
inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
|
||||||
|
batch.inference_params = inference_params
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
|
||||||
|
|
||||||
|
batch.inference_params = new_inference_params
|
||||||
|
# Results
|
||||||
|
generations: List[Generation] = []
|
||||||
|
stopped = True
|
||||||
|
|
||||||
|
# Speculation is not active for causal
|
||||||
|
accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
|
||||||
|
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||||
|
batch.top_n_tokens,
|
||||||
|
batch.top_n_tokens_tensor,
|
||||||
|
torch.log_softmax(logits[:, -1], -1),
|
||||||
|
accepted_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_decode = time.time_ns()
|
||||||
|
|
||||||
|
# Zipped iterator
|
||||||
|
iterator = zip(
|
||||||
|
batch.requests,
|
||||||
|
batch.input_lengths,
|
||||||
|
batch.prefix_offsets,
|
||||||
|
batch.read_offsets,
|
||||||
|
logits,
|
||||||
|
batch.next_token_choosers,
|
||||||
|
batch.stopping_criterias,
|
||||||
|
batch.all_input_ids,
|
||||||
|
batch.top_n_tokens,
|
||||||
|
batch_top_token_ids,
|
||||||
|
batch_top_token_logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For each member of the batch
|
||||||
|
for i, (
|
||||||
|
request,
|
||||||
|
input_length,
|
||||||
|
prefix_offset,
|
||||||
|
read_offset,
|
||||||
|
logits,
|
||||||
|
next_token_chooser,
|
||||||
|
stopping_criteria,
|
||||||
|
all_input_ids,
|
||||||
|
top_n_tokens,
|
||||||
|
top_token_ids,
|
||||||
|
top_token_logprobs,
|
||||||
|
) in enumerate(iterator):
|
||||||
|
# Select next token
|
||||||
|
next_token_id, logprobs = next_token_chooser(
|
||||||
|
all_input_ids.view(1, -1), logits[-1:, :]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append next token to all tokens
|
||||||
|
all_input_ids = torch.cat([all_input_ids, next_token_id])
|
||||||
|
new_input_length = input_length + 1
|
||||||
|
|
||||||
|
# Generated token
|
||||||
|
next_token_logprob = logprobs[-1, next_token_id]
|
||||||
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
|
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||||
|
all_input_ids[:, 0], prefix_offset, read_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate stopping criteria
|
||||||
|
stop, reason = stopping_criteria(
|
||||||
|
next_token_id_squeezed,
|
||||||
|
next_token_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not stop:
|
||||||
|
stopped = False
|
||||||
|
|
||||||
|
# Shard generations
|
||||||
|
# All generations will be appended in the rust sharded client
|
||||||
|
if i % self.world_size == self.rank:
|
||||||
|
if stop:
|
||||||
|
# Decode generated tokens
|
||||||
|
output_text, _, _ = self.decode_token(
|
||||||
|
all_input_ids[:, 0],
|
||||||
|
prefix_offset=len(all_input_ids)
|
||||||
|
- stopping_criteria.current_tokens
|
||||||
|
- 1,
|
||||||
|
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
|
||||||
|
skip_special_tokens=True,
|
||||||
|
)
|
||||||
|
# Get seed
|
||||||
|
if isinstance(next_token_chooser.choice, Sampling):
|
||||||
|
seed = next_token_chooser.choice.seed
|
||||||
|
else:
|
||||||
|
seed = None
|
||||||
|
|
||||||
|
generated_text = GeneratedText(
|
||||||
|
output_text, stopping_criteria.current_tokens, reason, seed
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
generated_text = None
|
||||||
|
|
||||||
|
if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
|
||||||
|
# Remove generated token to only have prefill and add nan for first prompt token
|
||||||
|
prefill_logprobs = [float("nan")] + torch.log_softmax(
|
||||||
|
logits, -1
|
||||||
|
).gather(1, all_input_ids[1:]).squeeze(1)[
|
||||||
|
-new_input_length:-1
|
||||||
|
].tolist()
|
||||||
|
prefill_token_ids = all_input_ids[-new_input_length:-1]
|
||||||
|
prefill_texts = self.tokenizer.batch_decode(
|
||||||
|
prefill_token_ids,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
prefill_tokens = Tokens(
|
||||||
|
prefill_token_ids,
|
||||||
|
prefill_logprobs,
|
||||||
|
prefill_texts,
|
||||||
|
is_special=[],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefill_tokens = None
|
||||||
|
|
||||||
|
if top_n_tokens > 0:
|
||||||
|
toptoken_texts = self.tokenizer.batch_decode(
|
||||||
|
top_token_ids,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
special_toptokens = [
|
||||||
|
token_id in self.all_special_ids for token_id in top_token_ids
|
||||||
|
]
|
||||||
|
top_tokens = Tokens(
|
||||||
|
top_token_ids,
|
||||||
|
top_token_logprobs,
|
||||||
|
toptoken_texts,
|
||||||
|
special_toptokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
top_tokens = None
|
||||||
|
|
||||||
|
generation = Generation(
|
||||||
|
request.id,
|
||||||
|
prefill_tokens,
|
||||||
|
Tokens(
|
||||||
|
[next_token_id_squeezed],
|
||||||
|
[next_token_logprob],
|
||||||
|
[next_token_text],
|
||||||
|
[next_token_id_squeezed.item() in self.all_special_ids],
|
||||||
|
),
|
||||||
|
generated_text,
|
||||||
|
top_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
generations.append(generation)
|
||||||
|
|
||||||
|
# Update values
|
||||||
|
batch.input_ids[i, 0] = next_token_id
|
||||||
|
batch.all_input_ids[i] = all_input_ids
|
||||||
|
batch.input_lengths[i] = new_input_length
|
||||||
|
batch.prefix_offsets[i] = prefix_offset
|
||||||
|
batch.read_offsets[i] = read_offset
|
||||||
|
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||||
|
|
||||||
|
# We finished all generations in the batch; there is no next batch
|
||||||
|
if stopped:
|
||||||
|
forward_ns = start_decode - start
|
||||||
|
decode_ns = time.time_ns() - start_decode
|
||||||
|
return generations, None, (forward_ns, decode_ns)
|
||||||
|
|
||||||
|
# Slice unused values from prefill
|
||||||
|
batch.input_ids = batch.input_ids[:, :1]
|
||||||
|
|
||||||
|
forward_ns = start_decode - start
|
||||||
|
decode_ns = time.time_ns() - start_decode
|
||||||
|
return generations, batch, (forward_ns, decode_ns)
|
Loading…
Reference in New Issue