Compare commits
26 Commits: 399defbe94 ... bd10f790ed
Author | SHA1 | Date |
---|---|---|
Wang, Yi A | bd10f790ed | |
Wang, Yi A | 0b6c5fea7c | |
Nicolas Patry | 88702d8763 | |
OlivierDehaene | c38a7d7ddd | |
Ikko Eltociear Ashimine | 275caa04b1 | |
OlivierDehaene | eefea5ee31 | |
Nicolas Patry | 1b2670c823 | |
Christof Weickhardt | 9d8f21cace | |
OlivierDehaene | c2c98725f8 | |
Nicolas Patry | 6c2c44b84c | |
Nicolas Patry | 408dbc485c | |
oOraph | c2fd35d875 | |
Nicolas Patry | 842f6658e2 | |
Nicolas Patry | b83aab9bb3 | |
abhishek thakur | 10d9083b2d | |
OlivierDehaene | 30620a9a44 | |
OlivierDehaene | ad9d6288c8 | |
Nicolas Patry | 4634b00c2a | |
Nicolas Patry | 106d8ee818 | |
OlivierDehaene | ff42d33e99 | |
oOraph | 53c2c3dbc7 | |
Nicolas Patry | 8dca3b04f8 | |
Nicolas Patry | f9958ee191 | |
Nicolas Patry | 5062fda4ff | |
Nicolas Patry | c7e570e59d | |
Nicolas Patry | 99874eae74 | |

@@ -13,7 +13,10 @@ jobs:
       - name: Install Launcher
         id: install-launcher
-        run: cargo install --git https://github.com/${{ github.repository }} --branch ${{ github.head_ref }} text-generation-launcher
+        env:
+          REF: ${{ github.head_ref }}
+          REPO: ${{ github.repository }}
+        run: cargo install --git "https://github.com/$REPO" --branch "$REF" text-generation-launcher
       - name: Check launcher Docs are up-to-date
         run: |

File diff suppressed because it is too large

@@ -9,7 +9,7 @@ members = [
 resolver = "2"
 
 [workspace.package]
-version = "1.4.5"
+version = "2.0.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

@@ -17,5 +17,7 @@ homepage = "https://github.com/huggingface/text-generation-inference"
 [profile.release]
 debug = 1
 incremental = true
-lto = "off"
+lto = "fat"
+opt-level = 3
+codegen-units = 1
 panic = "abort"

Dockerfile

@@ -85,7 +85,7 @@ FROM pytorch-install as kernel-builder
 ARG MAX_JOBS=8
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-        ninja-build \
+        ninja-build cmake \
         && rm -rf /var/lib/apt/lists/*
 
 # Build Flash Attention CUDA kernels

@@ -160,11 +160,6 @@ WORKDIR /usr/src
 COPY server/Makefile-selective-scan Makefile
 RUN make build-all
 
-# Build megablocks
-FROM kernel-builder as megablocks-builder
-
-RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
-
 # Text Generation Inference base image
 FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base

@@ -186,8 +181,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     && rm -rf /var/lib/apt/lists/*
 
-# Copy conda with PyTorch and Megablocks installed
-COPY --from=megablocks-builder /opt/conda /opt/conda
+# Copy conda with PyTorch installed
+COPY --from=pytorch-install /opt/conda /opt/conda
 
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

@@ -215,7 +210,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c
 COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 
-# Install flash-attention dependencies
+# Install vllm/flash-attention dependencies
 RUN pip install einops --no-cache-dir
 
 # Install server

@@ -250,5 +245,7 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base
 
-ENTRYPOINT ["text-generation-launcher"]
+COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
+
+ENTRYPOINT ["/tgi-entrypoint.sh"]
 CMD ["--json-output"]

LICENSE

@@ -1,181 +1,201 @@
|
|||
Hugging Face Optimized Inference License 1.0 (HFOILv1.0)
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
This License Agreement governs the use of the Software and its Modifications. It is a
|
||||
binding agreement between the Licensor and You.
|
||||
1. Definitions.
|
||||
|
||||
This License Agreement shall be referred to as Hugging Face Optimized Inference License
|
||||
1.0 or HFOILv1.0. We may publish revised versions of this License Agreement from time to
|
||||
time. Each version will be given a distinguished number.
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
By downloading, accessing, modifying, distributing or otherwise using the Software, You
|
||||
consent to all of the terms and conditions below. So, if You do not agree with those,
|
||||
please do not download, access, modify, distribute, or use the Software.
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
1. PERMISSIONS
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
You may use, modify and distribute the Software pursuant to the following terms and
|
||||
conditions:
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
Copyright License. Subject to the terms and conditions of this License Agreement and where
|
||||
and as applicable, each Contributor hereby grants You a perpetual, worldwide,
|
||||
non-exclusive, royalty-free, copyright license to reproduce, prepare, publicly display,
|
||||
publicly perform, sublicense under the terms herein, and distribute the Software and
|
||||
Modifications of the Software.
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
Patent License. Subject to the terms and conditions of this License Agreement and where
|
||||
and as applicable, each Contributor hereby grants You a perpetual, worldwide,
|
||||
non-exclusive, royalty-free patent license to make, have made, Use, offer to sell, sell,
|
||||
import, and otherwise transfer the Software, where such license applies only to those
|
||||
patent claims licensable by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s) with the Software to
|
||||
which such Contribution(s) was submitted. If You institute patent litigation against any
|
||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Software
|
||||
or a Contribution incorporated within the Software constitutes direct or contributory
|
||||
patent infringement, then any rights granted to You under this License Agreement for the
|
||||
Software shall terminate as of the date such litigation is filed.
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
No other rights. All rights not expressly granted herein are retained.
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
2. RESTRICTIONS
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
You may not distribute the Software as a hosted or managed, and paid service, where the
|
||||
service grants users access to any substantial set of the features or functionality of the
|
||||
Software. If you wish to do so, You will need to be granted additional rights from the
|
||||
Licensor which will be subject to a separate mutually agreed agreement.
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
You may not sublicense the Software under any other terms than those listed in this
|
||||
License.
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
3. OBLIGATIONS
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
When You modify the Software, You agree to: - attach a notice stating the Modifications of
|
||||
the Software You made; and - attach a notice stating that the Modifications of the
|
||||
Software are released under this License Agreement.
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
When You distribute the Software or Modifications of the Software, You agree to: - give
|
||||
any recipients of the Software a copy of this License Agreement; - retain all Explanatory
|
||||
Documentation; and if sharing the Modifications of the Software, add Explanatory
|
||||
Documentation documenting the changes made to create the Modifications of the Software; -
|
||||
retain all copyright, patent, trademark and attribution notices.
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
4. MISCELLANEOUS
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
Termination. Licensor reserves the right to restrict Use of the Software in violation of
|
||||
this License Agreement, upon which Your licenses will automatically terminate.
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
Contributions. Unless You explicitly state otherwise, any Contribution intentionally
|
||||
submitted for inclusion in the Software by You to the Licensor shall be under the terms
|
||||
and conditions of this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify the terms of any
|
||||
separate license agreement you may have executed with Licensor regarding such
|
||||
Contributions.
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
Trademarks and related. Nothing in this License Agreement permits You (i) to make Use of
|
||||
Licensors’ trademarks, trade names, or logos, (ii) otherwise suggest endorsement by
|
||||
Licensor, or (iii) misrepresent the relationship between the parties; and any rights not
|
||||
expressly granted herein are reserved by the Licensors.
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
Output You generate. Licensor claims no rights in the Output. You agree not to contravene
|
||||
any provision as stated in the License Agreement with your Use of the Output.
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
Disclaimer of Warranty. Except as expressly provided otherwise herein, and to the fullest
|
||||
extent permitted by law, Licensor provides the Software (and each Contributor provides its
|
||||
Contributions) AS IS, and Licensor disclaims all warranties or guarantees of any kind,
|
||||
express or implied, whether arising under any law or from any usage in trade, or otherwise
|
||||
including but not limited to the implied warranties of merchantability, non-infringement,
|
||||
quiet enjoyment, fitness for a particular purpose, or otherwise. You are solely
|
||||
responsible for determining the appropriateness of the Software and Modifications of the
|
||||
Software for your purposes (including your use or distribution of the Software and
|
||||
Modifications of the Software), and assume any risks associated with Your exercise of
|
||||
permissions under this License Agreement.
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
Limitation of Liability. In no event and under no legal theory, whether in tort (including
|
||||
negligence), contract, or otherwise, unless required by applicable law (such as deliberate
|
||||
and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to
|
||||
You for damages, including any direct, indirect, special, incidental, or consequential
|
||||
damages of any character arising as a result of this License Agreement or out of the Use
|
||||
or inability to Use the Software (including but not limited to damages for loss of
|
||||
goodwill, work stoppage, computer failure or malfunction, model failure or malfunction, or
|
||||
any and all other commercial damages or losses), even if such Contributor has been advised
|
||||
of the possibility of such damages.
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Accepting Warranty or Additional Liability. While sharing the Software or Modifications of
|
||||
the Software thereof, You may choose to offer and charge a fee for, acceptance of support,
|
||||
warranty, indemnity, or other liability obligations and/or rights consistent with this
|
||||
License Agreement. However, in accepting such obligations, You may act only on Your own
|
||||
behalf and on Your sole responsibility, not on behalf of Licensor or any other
|
||||
Contributor, and you hereby agree to indemnify, defend, and hold Licensor and each other
|
||||
Contributor (and their successors or assigns) harmless for any liability incurred by, or
|
||||
claims asserted against, such Licensor or Contributor (and their successors or assigns) by
|
||||
reason of your accepting any such warranty or additional liability.
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
Severability. This License Agreement is a license of copyright and patent rights and an
|
||||
agreement in contract between You and the Licensor. If any provision of this License
|
||||
Agreement is held to be invalid, illegal or unenforceable, the remaining provisions shall
|
||||
be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2022 Hugging Face
|
||||
|
||||
5. DEFINITIONS
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
“Contribution” refers to any work of authorship, including the original version of the
|
||||
Software and any Modifications of the Software that is intentionally submitted to Licensor
|
||||
for inclusion in the Software by the copyright owner or by an individual or entity
|
||||
authorized to submit on behalf of the copyright owner. For the purposes of this
|
||||
definition, “submitted” means any form of electronic, verbal, or written communication
|
||||
sent to the Licensor or its representatives, including but not limited to communication on
|
||||
electronic mailing lists, source code control systems, and issue tracking systems that are
|
||||
managed by, or on behalf of, the Licensor for the purpose of discussing and improving the
|
||||
Software, but excluding communication that is conspicuously marked or otherwise designated
|
||||
in writing by the copyright owner as “Not a Contribution.”
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
“Contributor” refers to Licensor and any individual or entity on behalf of whom a
|
||||
Contribution has been received by Licensor and subsequently incorporated within the
|
||||
Software.
|
||||
|
||||
“Data” refers to a collection of information extracted from the dataset used with the
|
||||
Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not
|
||||
licensed under this License Agreement.
|
||||
|
||||
“Explanatory Documentation” refers to any documentation or related information including
|
||||
but not limited to model cards or data cards dedicated to inform the public about the
|
||||
characteristics of the Software. Explanatory documentation is not licensed under this
|
||||
License.
|
||||
|
||||
"License Agreement" refers to these terms and conditions.
|
||||
|
||||
“Licensor” refers to the rights owners or entity authorized by the rights owners that are
|
||||
granting the terms and conditions of this License Agreement.
|
||||
|
||||
“Model” refers to machine-learning based assemblies (including checkpoints), consisting of
|
||||
learnt weights and parameters (including optimizer states), corresponding to a model
|
||||
architecture as embodied in Software source code. Source code is not licensed under this
|
||||
License Agreement.
|
||||
|
||||
“Modifications of the Software” refers to all changes to the Software, including without
|
||||
limitation derivative works of the Software.
|
||||
|
||||
“Output” refers to the results of operating the Software.
|
||||
|
||||
“Share” refers to any transmission, reproduction, publication or other sharing of the
|
||||
Software or Modifications of the Software to a third party, including providing the
|
||||
Software as a hosted service made available by electronic or other remote means,
|
||||
including - but not limited to - API-based or web access.
|
||||
|
||||
“Software” refers to the software and Model (or parts of either) that Licensor makes
|
||||
available under this License Agreement.
|
||||
|
||||
“Third Parties” refers to individuals or legal entities that are not under common control
|
||||
with Licensor or You.
|
||||
|
||||
“Use” refers to anything You or your representatives do with the Software, including but
|
||||
not limited to generating any Output, fine tuning, updating, running, training, evaluating
|
||||
and/or reparametrizing the Model.
|
||||
|
||||
"You" (or "Your") refers to an individual or Legal Entity exercising permissions granted
|
||||
by this License Agreement and/or making Use of the Software for whichever purpose and in
|
||||
any field of Use.
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
|
|

@@ -76,13 +76,13 @@ For a detailed starting guide, please see the [Quick Tour](https://huggingface.c
 model=HuggingFaceH4/zephyr-7b-beta
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 
-docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
+docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
 ```
 
 And then you can make requests like
 
 ```bash
-curl 127.0.0.1:8080/generate \
+curl 127.0.0.1:8080/generate_stream \
     -X POST \
     -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
     -H 'Content-Type: application/json'
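Since the quick-start example above now targets `/generate_stream`, here is a minimal Python sketch of consuming that endpoint. This is a hedged illustration, not code from this repository: it assumes a server is already running on 127.0.0.1:8080 as in the `docker run` command above, and that the stream is server-sent events whose `data:` payloads carry a JSON `token` object.

```python
import json
import requests

# Stream tokens from a running TGI server (assumed to be listening on 127.0.0.1:8080).
payload = {
    "inputs": "What is Deep Learning?",
    "parameters": {"max_new_tokens": 20},
}
with requests.post(
    "http://127.0.0.1:8080/generate_stream", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        # Server-sent events: payload lines are prefixed with "data:".
        if line.startswith(b"data:"):
            event = json.loads(line[len(b"data:"):])
            print(event["token"]["text"], end="", flush=True)
print()
```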
@@ -90,7 +90,7 @@ curl 127.0.0.1:8080/generate \
 
 **Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
 
-**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4-rocm --model-id $model` instead of the command above.
+**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0-rocm --model-id $model` instead of the command above.
 
 To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
 ```

@@ -120,7 +120,7 @@ model=meta-llama/Llama-2-7b-chat-hf
 volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
 token=<your cli READ token>
 
-docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 --model-id $model
+docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0 --model-id $model
 ```
 
 ### A note on Shared Memory (shm)
|

@@ -10,7 +10,7 @@
     "name": "Apache 2.0",
     "url": "https://www.apache.org/licenses/LICENSE-2.0"
   },
-  "version": "1.4.5"
+  "version": "2.0.0"
 },
 "paths": {
   "/": {
|

@@ -23,6 +23,8 @@
       title: All TGI CLI options
     - local: basic_tutorials/non_core_models
       title: Non-core Model Serving
+    - local: basic_tutorials/safety
+      title: Safety
   title: Tutorials
 - sections:
   - local: conceptual/streaming
|

@@ -60,12 +60,13 @@ Options:
           [env: QUANTIZE=]
 
           Possible values:
-          - awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
-          - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
-          - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
+          - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
+          - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
+          - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
           - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
           - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
           - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
+          - fp8: [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations
 
 ```
 ## SPECULATE
|

@@ -128,23 +129,29 @@ Options:
           [env: MAX_TOP_N_TOKENS=]
           [default: 5]
 
 ```
+## MAX_INPUT_TOKENS
+```shell
+      --max-input-tokens <MAX_INPUT_TOKENS>
+          This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
+
+          [env: MAX_INPUT_TOKENS=]
+
+```
 ## MAX_INPUT_LENGTH
 ```shell
       --max-input-length <MAX_INPUT_LENGTH>
-          This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle
+          Legacy version of [`Args::max_input_tokens`]
 
           [env: MAX_INPUT_LENGTH=]
-          [default: 1024]
 
 ```
 ## MAX_TOTAL_TOKENS
 ```shell
       --max-total-tokens <MAX_TOTAL_TOKENS>
-          This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be
+          This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)
 
           [env: MAX_TOTAL_TOKENS=]
-          [default: 2048]
 
 ```
 ## WAITING_SERVED_RATIO

@@ -161,10 +168,9 @@ Options:
 ## MAX_BATCH_PREFILL_TOKENS
 ```shell
       --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
-          Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent
+          Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent. Default to `max_input_tokens + 50` to give a bit of room
 
           [env: MAX_BATCH_PREFILL_TOKENS=]
-          [default: 4096]
 
 ```
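To make the budget arithmetic above concrete, here is a minimal Python sketch of the kind of check these limits imply. It is a hedged illustration with hypothetical variable names, not the launcher's actual code.

```python
# Hedged sketch of the "memory budget" rule described above (not TGI's actual code).
max_input_tokens = 1000       # longest prompt a user may send
max_total_tokens = 1512       # prompt + generated tokens per request
max_batch_prefill_tokens = max_input_tokens + 50  # documented default: a bit of room on top

def validate_request(prompt_tokens: int, max_new_tokens: int) -> None:
    if prompt_tokens > max_input_tokens:
        raise ValueError(f"prompt too long: {prompt_tokens} > {max_input_tokens}")
    if prompt_tokens + max_new_tokens > max_total_tokens:
        raise ValueError(
            f"budget exceeded: {prompt_tokens} + {max_new_tokens} > {max_total_tokens}"
        )

validate_request(prompt_tokens=1000, max_new_tokens=512)   # fits: 1000 + 512 == 1512
validate_request(prompt_tokens=1, max_new_tokens=1511)     # also fits
```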
## MAX_BATCH_TOTAL_TOKENS

@@ -206,12 +212,12 @@ Options:
       [env: MAX_BATCH_SIZE=]
 
 ```
-## ENABLE_CUDA_GRAPHS
+## CUDA_GRAPHS
 ```shell
-      --enable-cuda-graphs
-          Enable experimental support for cuda graphs
+      --cuda-graphs <CUDA_GRAPHS>
+          Specify the batch sizes to compute cuda graphs for. Use "0" to disable. Default = "1,2,4,8,16,32"
 
-          [env: ENABLE_CUDA_GRAPHS=]
+          [env: CUDA_GRAPHS=]
 
 ```
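For readers tracking how this option reaches the Python shards: the launcher joins the requested sizes into a comma-separated `CUDA_GRAPHS` environment variable (see the `shard_manager` change later in this diff). Below is a minimal sketch of how a consumer could parse it; the helper is hypothetical and not the server's actual code.

```python
import os

def parse_cuda_graph_sizes(default: str = "1,2,4,8,16,32") -> list[int]:
    """Parse the comma-separated CUDA_GRAPHS env var; "0" (or empty) disables cuda graphs."""
    raw = os.environ.get("CUDA_GRAPHS", default)
    sizes = [int(item) for item in raw.split(",") if item.strip()]
    return [] if sizes == [0] else sizes

print(parse_cuda_graph_sizes())  # -> [1, 2, 4, 8, 16, 32] unless CUDA_GRAPHS is set
```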
## HOSTNAME
@@ -0,0 +1,31 @@
+# Model safety.
+
+[Pytorch uses pickle](https://pytorch.org/docs/master/generated/torch.load.html) by default, meaning that for quite a long while
+*every* model using that format was potentially executing unintended code simply by being loaded.
+
+There is a big red warning on Python's page for pickle [link](https://docs.python.org/3/library/pickle.html), but for quite a while
+this was ignored by the community. Now that AI/ML is used much more ubiquitously, we need to switch away from this format.
+
+HuggingFace is leading the effort here by creating a new format which contains pure data ([safetensors](https://github.com/huggingface/safetensors))
+and moving slowly but surely all the libs to make use of it by default.
+The move is intentionally slow so that breaking changes have as little impact as possible on users.
+
+
+# TGI 2.0
+
+With the release of TGI 2.0, we take the opportunity of this major version increase to break backward compatibility for these pytorch
+models (since they are a huge security risk for anyone deploying them).
+
+
+From now on, TGI will no longer convert pickle files automatically unless the `--trust-remote-code` flag is set or `TRUST_REMOTE_CODE=true` is present in the environment variables.
+This flag is already used for community-defined inference code, and is therefore quite representative of the level of confidence you are giving the model providers.
+
+
+If you want to use a model that uses pickle, but you still do not want to trust the authors entirely, we recommend making a conversion on our Space made for that purpose:
+
+https://huggingface.co/spaces/safetensors/convert
+
+This space will create a PR on the original model, which you can use directly regardless of merge status from the original authors. Just use
+```
+docker run .... --revision refs/pr/#ID # Or use REVISION=refs/pr/#ID in the environment
+```
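As a concrete illustration of why the safetensors format avoids the pickle problem, here is a minimal, hedged Python sketch; the file names are placeholders and this is not code from TGI itself.

```python
from safetensors.torch import load_file

# safetensors is a pure-data format: deserializing it cannot run arbitrary code.
weights = load_file("model.safetensors")  # placeholder path

# A pickle-based checkpoint, by contrast, can execute arbitrary code on load,
# which is exactly the risk described above:
# import torch; unsafe_weights = torch.load("pytorch_model.bin")  # only for sources you trust
```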

@@ -74,7 +74,7 @@ curl localhost:3000/generate \
 
 A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
 
-> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compliation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
+> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compilation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
 
 ### Constrain with Pydantic
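To ground the note above, here is a hedged Python sketch of sending a grammar-constrained request to a running server on localhost:3000 as in the surrounding docs. The exact shape of the `grammar` parameter is assumed from this guidance page and may differ in detail.

```python
import requests

# Hedged sketch: constrain generation with a regular-expression grammar.
payload = {
    "inputs": "What is the IP address of the Google DNS servers?",
    "parameters": {
        "max_new_tokens": 30,
        "grammar": {
            "type": "regex",
            "value": r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
        },
    },
}
resp = requests.post("http://localhost:3000/generate", json=payload)
# The first call pays the grammar-compilation cost; later calls hit the cache.
print(resp.json()["generated_text"])
```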

@@ -22,6 +22,8 @@ The following models are optimized and can be served with TGI, which uses custom
 - [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
 - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
 - [Phi](https://huggingface.co/microsoft/phi-2)
+- [Idefics](HuggingFaceM4/idefics-9b-instruct) (Multimodal)
+- [Llava-next](llava-hf/llava-v1.6-mistral-7b-hf) (Multimodal)
 
 If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
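For the newly listed multimodal models, prompts embed images as inline data URIs, as the Llava-next integration test added later in this diff does. A minimal Python sketch of building such a prompt, mirroring that test's `get_chicken()` helper and its image path:

```python
import base64

# Build a multimodal prompt with an inline base64 image.
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
    encoded = base64.b64encode(image_file.read()).decode("utf-8")

image_uri = f"data:image/png;base64,{encoded}"
prompt = f"User:![]({image_uri})Can you tell me a very short story based on the image?"
```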
|

@@ -277,6 +277,8 @@ def launcher(event_loop):
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
+        max_input_length: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)

@@ -314,6 +316,12 @@ def launcher(event_loop):
             args.append(revision)
         if trust_remote_code:
             args.append("--trust-remote-code")
+        if max_input_length:
+            args.append("--max-input-length")
+            args.append(str(max_input_length))
+        if max_total_tokens:
+            args.append("--max-total-tokens")
+            args.append(str(max_total_tokens))
 
         env["LOG_LEVEL"] = "info,text_generation_router=debug"

@@ -347,6 +355,8 @@ def launcher(event_loop):
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
         revision: Optional[str] = None,
+        max_input_length: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         port = random.randint(8000, 10_000)

@@ -367,6 +377,12 @@ def launcher(event_loop):
             args.append(revision)
         if trust_remote_code:
             args.append("--trust-remote-code")
+        if max_input_length:
+            args.append("--max-input-length")
+            args.append(str(max_input_length))
+        if max_total_tokens:
+            args.append("--max-total-tokens")
+            args.append(str(max_total_tokens))
 
         client = docker.from_env()

@@ -383,7 +399,6 @@ def launcher(event_loop):
 
         env = {
             "LOG_LEVEL": "info,text_generation_router=debug",
-            "ENABLE_CUDA_GRAPHS": "true",
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
|
@ -0,0 +1,65 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "stop_sequence",
|
||||
"generated_tokens": 6,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -10.5,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -12.140625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.0654297,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 1014,
|
||||
"logprob": -2.7460938,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 6032,
|
||||
"logprob": -1.359375,
|
||||
"special": false,
|
||||
"text": " purpose"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 456,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 1369,
|
||||
"logprob": -0.40063477,
|
||||
"special": false,
|
||||
"text": " test"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request\nThe purpose of this test"
|
||||
}
|
File diff suppressed because it is too large
|
@ -0,0 +1,73 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.00756073,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.20117188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 16114,
|
||||
"logprob": -1.2597656,
|
||||
"special": false,
|
||||
"text": "Once"
|
||||
},
|
||||
{
|
||||
"id": 3714,
|
||||
"logprob": -0.20825195,
|
||||
"special": false,
|
||||
"text": " upon"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.00178051,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 727,
|
||||
"logprob": -0.011955261,
|
||||
"special": false,
|
||||
"text": " time"
|
||||
},
|
||||
{
|
||||
"id": 28725,
|
||||
"logprob": -0.17541504,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 736,
|
||||
"logprob": -0.91308594,
|
||||
"special": false,
|
||||
"text": " there"
|
||||
},
|
||||
{
|
||||
"id": 403,
|
||||
"logprob": -0.058410645,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.009689331,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nOnce upon a time, there was a"
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 9,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
|
@ -14,7 +14,7 @@
|
|||
"tokens": [
|
||||
{
|
||||
"id": 16017,
|
||||
"logprob": -0.30908203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " blue"
|
||||
},
|
||||
|
@ -26,39 +26,45 @@
|
|||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -0.28271484,
|
||||
"logprob": -0.4716797,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 15484,
|
||||
"logprob": -1.7929688,
|
||||
"id": 261,
|
||||
"logprob": -0.044677734,
|
||||
"special": false,
|
||||
"text": "appear"
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 345,
|
||||
"logprob": -0.8935547,
|
||||
"id": 35622,
|
||||
"logprob": -0.79589844,
|
||||
"special": false,
|
||||
"text": "ed"
|
||||
"text": " cloud"
|
||||
},
|
||||
{
|
||||
"id": 281,
|
||||
"id": 263,
|
||||
"logprob": -1.2958984,
|
||||
"special": false,
|
||||
"text": "s"
|
||||
},
|
||||
{
|
||||
"id": 305,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"id": 35622,
|
||||
"logprob": -1.1630859,
|
||||
"special": false,
|
||||
"text": " cloud"
|
||||
},
|
||||
{
|
||||
"id": 263,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 20495,
|
||||
"logprob": -0.32299805,
|
||||
"special": false,
|
||||
"text": " sky"
|
||||
"text": "s"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
|
@ -66,7 +72,8 @@
|
|||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
|
||||
"generated_text": "Why is the sky blue?blue sky, clouds and clouds"
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 60,
|
||||
|
|
|
@ -31,7 +31,7 @@
|
|||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
|
|
|
@ -31,7 +31,7 @@
|
|||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native",
|
||||
"system_fingerprint": "2.0.0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 21,
|
||||
"prompt_tokens": 187,
|
||||
|
|
|
@ -23,5 +23,5 @@
|
|||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.5-native"
|
||||
"system_fingerprint": "2.0.0-native"
|
||||
}
|
||||
|
|
|
@ -33,6 +33,9 @@ async def test_idefics(idefics, response_snapshot):
|
|||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert (
|
||||
response.generated_text == " \nAssistant: A rooster stands"
|
||||
), f"{repr(response.generated_text)}"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
|
@ -48,6 +51,9 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
|
|||
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
|
||||
assert (
|
||||
generated_texts[0] == " \nAssistant: A rooster stands"
|
||||
), f"{response.generated_text}"
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llava_next_handle(launcher):
|
||||
with launcher(
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
num_shard=4,
|
||||
max_input_length=4000,
|
||||
max_total_tokens=4096,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llava_next(flash_llava_next_handle):
|
||||
await flash_llava_next_handle.health(300)
|
||||
return flash_llava_next_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
response = await flash_llava_next.generate(
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
)
|
||||
assert (
|
||||
response.generated_text == "\n\nOnce upon a time, there was a"
|
||||
), f"{repr(response.generated_text)}"
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
|
||||
response = await flash_llava_next.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 6
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llava_next_load(
|
||||
flash_llava_next, generate_load, response_snapshot
|
||||
):
|
||||
chicken = get_chicken()
|
||||
responses = await generate_load(
|
||||
flash_llava_next,
|
||||
f"User:![]({chicken})Can you tell me a very short story based on the image?",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
assert generated_texts[0] == "\n\nOnce upon a time, there was a"
|
||||
assert len(generated_texts) == 4
|
||||
assert all([r.generated_text == generated_texts[0] for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|

@@ -45,7 +45,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
         seed=0,
     )
 
-    assert response.details.generated_tokens == 9
+    assert response.details.generated_tokens == 10
     assert response == response_snapshot

@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def t5_sharded_handle(launcher):
-    with launcher("google/flan-t5-xxl", num_shard=2) as handle:
+    with launcher("google/flan-t5-xxl", num_shard=4) as handle:
         yield handle
|

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation-integration-tests"
-version = "1.4.5"
+version = "2.0.0"
 description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]
|

@@ -9,8 +9,10 @@ homepage.workspace = true
 [dependencies]
 clap = { version = "4.4.5", features = ["derive", "env"] }
 ctrlc = { version = "3.4.1", features = ["termination"] }
+hf-hub = "0.3.2"
+nix = { version = "0.28.0", features = ["signal"] }
-serde = { version = "1.0.188", features = ["derive"] }
 once_cell = "1.19.0"
+serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.107"
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|

@@ -1,4 +1,5 @@
 use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use nix::sys::signal::{self, Signal};
 use nix::unistd::Pid;
 use serde::Deserialize;

@@ -19,17 +20,23 @@ use tracing_subscriber::EnvFilter;
 mod env_runtime;
 
+#[derive(Deserialize)]
+struct Config {
+    max_position_embeddings: Option<usize>,
+    max_seq_len: Option<usize>,
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Quantization {
     /// 4 bit quantization. Requires a specific AWQ quantized model:
-    /// https://hf.co/models?search=awq.
+    /// <https://hf.co/models?search=awq>.
     /// Should replace GPTQ models wherever possible because of the better latency
     Awq,
     /// 8 bit quantization, doesn't require specific model.
     /// Should be a drop-in replacement to bitsandbytes with much better performance.
-    /// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
+    /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
     Eetq,
-    /// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq.
+    /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
     /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
     /// triton kernel (wider support) when it's not.
     /// AWQ has faster kernels.
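The new `Config` struct lets the launcher read `max_position_embeddings` (or `max_seq_len`) from the model's `config.json` and derive the defaults documented earlier (`max_input_tokens = min(max_position_embeddings - 1, 4095)`, `max_total_tokens = min(max_position_embeddings, 4096)`). A hedged Python sketch of that resolution rule, not the launcher's actual code:

```python
import json

# Hedged sketch of the default-resolution rule described in the launcher docs above.
with open("config.json") as f:  # the model's config, downloaded alongside the weights
    config = json.load(f)

max_position_embeddings = (
    config.get("max_position_embeddings") or config.get("max_seq_len") or 4096
)
max_input_tokens = min(max_position_embeddings - 1, 4095)
max_total_tokens = min(max_position_embeddings, 4096)
print(max_input_tokens, max_total_tokens)
```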
@ -47,6 +54,11 @@ enum Quantization {
|
|||
/// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better
|
||||
/// perplexity performance for you model
|
||||
BitsandbytesFP4,
|
||||
/// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above
|
||||
/// This dtype has native ops should be the fastest if available.
|
||||
/// This is currently not the fastest because of local unpacking + padding to satisfy matrix
|
||||
/// multiplication limitations.
|
||||
Fp8,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Quantization {
|
||||
|
@ -73,6 +85,9 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Eetq => {
|
||||
write!(f, "eetq")
|
||||
}
|
||||
Quantization::Fp8 => {
|
||||
write!(f, "fp8")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -206,8 +221,13 @@ struct Args {
|
|||
/// for users. The larger this value, the longer prompt users can send which
|
||||
/// can impact the overall memory required to handle the load.
|
||||
/// Please note that some models have a finite range of sequence they can handle.
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_length: usize,
|
||||
/// Default to min(max_position_embeddings - 1, 4095)
|
||||
#[clap(long, env)]
|
||||
max_input_tokens: Option<usize>,
|
||||
|
||||
/// Legacy version of [`Args::max_input_tokens`].
|
||||
#[clap(long, env)]
|
||||
max_input_length: Option<usize>,
|
||||
|
||||
/// This is the most important value to set as it defines the "memory budget"
|
||||
/// of running clients requests.
|
||||
|
@ -217,8 +237,9 @@ struct Args {
|
|||
/// `1511` max_new_tokens.
|
||||
/// The larger this value, the larger amount each request will be in your RAM
|
||||
/// and the less effective batching can be.
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
/// Default to min(max_position_embeddings, 4096)
|
||||
#[clap(long, env)]
|
||||
max_total_tokens: Option<usize>,
|
||||
|
||||
/// This represents the ratio of waiting queries vs running queries where
|
||||
/// you want to start considering pausing the running queries to include the waiting
|
||||
|
@ -236,8 +257,9 @@ struct Args {
|
|||
/// Limits the number of tokens for the prefill operation.
|
||||
/// Since this operation takes the most memory and is compute bound, it is interesting
|
||||
/// to limit the number of requests that can be sent.
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
/// Default to `max_input_tokens + 50` to give a bit of room.
|
||||
#[clap(long, env)]
|
||||
max_batch_prefill_tokens: Option<u32>,
|
||||
|
||||
/// **IMPORTANT** This is one critical control to allow maximum usage
|
||||
/// of the available hardware.
|
||||
|
@ -284,9 +306,11 @@ struct Args {
|
|||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
|
||||
/// Enable experimental support for cuda graphs
|
||||
#[clap(long, env)]
|
||||
enable_cuda_graphs: bool,
|
||||
/// Specify the batch sizes to compute cuda graphs for.
|
||||
/// Use "0" to disable.
|
||||
/// Default = "1,2,4,8,16,32"
|
||||
#[clap(long, env, value_delimiter = ',')]
|
||||
cuda_graphs: Option<Vec<usize>>,
|
||||
|
||||
/// The IP address to listen on
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
|
@ -416,7 +440,7 @@ fn shard_manager(
|
|||
disable_custom_kernels: bool,
|
||||
watermark_gamma: Option<f32>,
|
||||
watermark_delta: Option<f32>,
|
||||
enable_cuda_graphs: bool,
|
||||
cuda_graphs: Vec<usize>,
|
||||
cuda_memory_fraction: f32,
|
||||
rope_scaling: Option<RopeScaling>,
|
||||
rope_factor: Option<f32>,
|
||||
|
@ -493,6 +517,9 @@ fn shard_manager(
|
|||
// Copy current process env
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Remove LOG_LEVEL if present
|
||||
envs.retain(|(name, _)| name != "LOG_LEVEL");
|
||||
|
||||
// Torch Distributed Env vars
|
||||
envs.push(("RANK".into(), rank.to_string().into()));
|
||||
envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
|
||||
|
@ -549,8 +576,16 @@ fn shard_manager(
|
|||
};
|
||||
|
||||
// Enable experimental support for cuda graphs
|
||||
if enable_cuda_graphs {
|
||||
envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
|
||||
if !cuda_graphs.is_empty() {
|
||||
envs.push((
|
||||
"CUDA_GRAPHS".into(),
|
||||
cuda_graphs
|
||||
.into_iter()
|
||||
.map(|c| c.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
.into(),
|
||||
));
|
||||
}
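The comma-joining above turns the batch-size list into a single env value; a standalone sketch of the same transformation:

    fn main() {
        let cuda_graphs: Vec<usize> = vec![1, 2, 4, 8, 16, 32];
        let joined = cuda_graphs
            .into_iter()
            .map(|c| c.to_string())
            .collect::<Vec<_>>()
            .join(",");
        // The shard process then sees CUDA_GRAPHS=1,2,4,8,16,32 in its environment.
        assert_eq!(joined, "1,2,4,8,16,32");
    }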
|
||||
|
||||
// If disable_custom_kernels is true, pass it to the shard as an env var
|
||||
|
@ -572,6 +607,7 @@ fn shard_manager(
|
|||
tracing::info!("Starting shard");
|
||||
let mut p = match Command::new("text-generation-server")
|
||||
.args(shard_args)
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
|
@ -782,6 +818,14 @@ enum LauncherError {
|
|||
WebserverCannotStart,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for LauncherError {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
|
||||
write!(f, "{self:?}")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for LauncherError {}
|
||||
|
||||
fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
|
||||
// Enter download tracing span
|
||||
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
|
||||
|
@ -810,6 +854,9 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||
// Copy current process env
|
||||
let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
// Remove LOG_LEVEL if present
|
||||
envs.retain(|(name, _)| name != "LOG_LEVEL");
|
||||
|
||||
// Disable progress bar
|
||||
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
|
||||
|
||||
|
@ -844,6 +891,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||
tracing::info!("Starting download process.");
|
||||
let mut download_process = match Command::new("text-generation-server")
|
||||
.args(download_args)
|
||||
.env_clear()
|
||||
.envs(envs)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
|
@ -914,6 +962,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
|
|||
fn spawn_shards(
|
||||
num_shard: usize,
|
||||
args: &Args,
|
||||
cuda_graphs: Vec<usize>,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
shutdown_sender: mpsc::Sender<()>,
|
||||
|
@ -941,7 +990,7 @@ fn spawn_shards(
|
|||
let disable_custom_kernels = args.disable_custom_kernels;
|
||||
let watermark_gamma = args.watermark_gamma;
|
||||
let watermark_delta = args.watermark_delta;
|
||||
let enable_cuda_graphs = args.enable_cuda_graphs;
|
||||
let cuda_graphs_clone = cuda_graphs.clone();
|
||||
let cuda_memory_fraction = args.cuda_memory_fraction;
|
||||
let rope_scaling = args.rope_scaling;
|
||||
let rope_factor = args.rope_factor;
|
||||
|
@ -963,7 +1012,7 @@ fn spawn_shards(
|
|||
disable_custom_kernels,
|
||||
watermark_gamma,
|
||||
watermark_delta,
|
||||
enable_cuda_graphs,
|
||||
cuda_graphs_clone,
|
||||
cuda_memory_fraction,
|
||||
rope_scaling,
|
||||
rope_factor,
|
||||
|
@ -1019,6 +1068,9 @@ fn compute_type(num_shard: usize) -> Option<String> {
|
|||
fn spawn_webserver(
|
||||
num_shard: usize,
|
||||
args: Args,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
max_batch_prefill_tokens: u32,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
shutdown_receiver: &mpsc::Receiver<()>,
|
||||
) -> Result<Child, LauncherError> {
|
||||
|
@ -1034,12 +1086,12 @@ fn spawn_webserver(
|
|||
args.max_stop_sequences.to_string(),
|
||||
"--max-top-n-tokens".to_string(),
|
||||
args.max_top_n_tokens.to_string(),
|
||||
"--max-input-length".to_string(),
|
||||
args.max_input_length.to_string(),
|
||||
"--max-input-tokens".to_string(),
|
||||
max_input_tokens.to_string(),
|
||||
"--max-total-tokens".to_string(),
|
||||
args.max_total_tokens.to_string(),
|
||||
max_total_tokens.to_string(),
|
||||
"--max-batch-prefill-tokens".to_string(),
|
||||
args.max_batch_prefill_tokens.to_string(),
|
||||
max_batch_prefill_tokens.to_string(),
|
||||
"--waiting-served-ratio".to_string(),
|
||||
args.waiting_served_ratio.to_string(),
|
||||
"--max-waiting-tokens".to_string(),
|
||||
|
@ -1217,19 +1269,129 @@ fn main() -> Result<(), LauncherError> {
|
|||
|
||||
tracing::info!("{:?}", args);
|
||||
|
||||
let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
|
||||
let model_id = args.model_id.clone();
|
||||
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
|
||||
let filename = if !path.exists() {
|
||||
// Assume it's a hub id
|
||||
let api = Api::new()?;
|
||||
let repo = if let Some(ref revision) = args.revision {
|
||||
api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
} else {
|
||||
api.model(model_id)
|
||||
};
|
||||
repo.get("config.json")?
|
||||
} else {
|
||||
path.push("config.json");
|
||||
path
|
||||
};
|
||||
|
||||
let content = std::fs::read_to_string(filename)?;
|
||||
let config: Config = serde_json::from_str(&content)?;
|
||||
|
||||
// Quantization usually means you're even more RAM constrained.
|
||||
let max_default = 4096;
|
||||
|
||||
let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
|
||||
(Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
|
||||
if max_position_embeddings > max_default {
|
||||
let max = max_position_embeddings;
|
||||
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
|
||||
max_default
|
||||
} else {
|
||||
max_position_embeddings
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(Box::new(LauncherError::ArgumentValidation(
|
||||
"no max defined".to_string(),
|
||||
)));
|
||||
}
|
||||
};
|
||||
Ok(max_position_embeddings)
|
||||
};
|
||||
let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
|
||||
|
||||
let max_input_tokens = {
|
||||
match (args.max_input_tokens, args.max_input_length) {
|
||||
(Some(max_input_tokens), Some(max_input_length)) => {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length` is deprecated for naming consistency.",
|
||||
)));
|
||||
}
|
||||
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
|
||||
(None, None) => {
|
||||
let value = max_position_embeddings - 1;
|
||||
tracing::info!("Default `max_input_tokens` to {value}");
|
||||
value
|
||||
}
|
||||
}
|
||||
};
|
||||
let max_total_tokens = {
|
||||
match args.max_total_tokens {
|
||||
Some(max_total_tokens) => max_total_tokens,
|
||||
None => {
|
||||
let value = max_position_embeddings;
|
||||
tracing::info!("Default `max_total_tokens` to {value}");
|
||||
value
|
||||
}
|
||||
}
|
||||
};
|
||||
let max_batch_prefill_tokens = {
|
||||
match args.max_batch_prefill_tokens {
|
||||
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
|
||||
None => {
|
||||
let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
|
||||
max_batch_size * max_input_tokens
|
||||
} else {
|
||||
// Adding some edge in order to account for potential block_size alignment
|
||||
// issue.
|
||||
max_input_tokens + 50
|
||||
} as u32;
|
||||
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
|
||||
value
|
||||
}
|
||||
}
|
||||
};
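Putting the three fallbacks above together, a simplified sketch of the default chain for a model with 4096 position embeddings and no CLI overrides (the helper function is hypothetical, not launcher code):

    fn resolve_defaults(max_position_embeddings: usize) -> (usize, usize, u32) {
        // Same fallback order as above: input tokens, total tokens, prefill budget.
        let max_input_tokens = max_position_embeddings - 1;
        let max_total_tokens = max_position_embeddings;
        let max_batch_prefill_tokens = (max_input_tokens + 50) as u32;
        (max_input_tokens, max_total_tokens, max_batch_prefill_tokens)
    }

    fn main() {
        assert_eq!(resolve_defaults(4096), (4095, 4096, 4145));
    }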
|
||||
|
||||
// Validate args
|
||||
if args.max_input_length >= args.max_total_tokens {
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(),
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if args.max_input_length as u32 > args.max_batch_prefill_tokens {
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, args.max_input_length
|
||||
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
|
||||
max_batch_prefill_tokens, max_input_tokens
|
||||
)));
|
||||
}
|
||||
|
||||
let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
|
||||
(Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
|
||||
#[allow(deprecated)]
|
||||
(
|
||||
None,
|
||||
Some(
|
||||
Quantization::Bitsandbytes
|
||||
| Quantization::BitsandbytesNF4
|
||||
| Quantization::BitsandbytesFP4,
|
||||
),
|
||||
) => {
|
||||
tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
|
||||
vec![]
|
||||
}
|
||||
_ => {
|
||||
let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
|
||||
tracing::info!("Using default cuda graphs {cuda_graphs:?}");
|
||||
cuda_graphs
|
||||
}
|
||||
};
|
||||
|
||||
if args.validation_workers == 0 {
|
||||
return Err(LauncherError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
|
@ -1248,16 +1410,16 @@ fn main() -> Result<(), LauncherError> {
|
|||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
|
||||
if args.max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_batch_prefill_tokens, max_batch_total_tokens
|
||||
max_batch_prefill_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
if args.max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(LauncherError::ArgumentValidation(format!(
|
||||
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
|
||||
args.max_total_tokens, max_batch_total_tokens
|
||||
max_total_tokens, max_batch_total_tokens
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
@ -1304,6 +1466,7 @@ fn main() -> Result<(), LauncherError> {
|
|||
spawn_shards(
|
||||
num_shard,
|
||||
&args,
|
||||
cuda_graphs,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
shutdown_sender,
|
||||
|
@ -1318,11 +1481,19 @@ fn main() -> Result<(), LauncherError> {
|
|||
return Ok(());
|
||||
}
|
||||
|
||||
let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
|
||||
.map_err(|err| {
|
||||
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
||||
err
|
||||
})?;
|
||||
let mut webserver = spawn_webserver(
|
||||
num_shard,
|
||||
args,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_prefill_tokens,
|
||||
shutdown.clone(),
|
||||
&shutdown_receiver,
|
||||
)
|
||||
.map_err(|err| {
|
||||
shutdown_shards(shutdown.clone(), &shutdown_receiver);
|
||||
err
|
||||
})?;
|
||||
|
||||
// Default exit code
|
||||
let mut exit_code = Ok(());
|
||||
|
|
|
@ -44,10 +44,12 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
|
|||
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
|
||||
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", branch = "main", commit = "5cd4efb" }
|
||||
minijinja = { git = "https://github.com/mitsuhiko/minijinja.git", rev = "5cd4efb" }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = "0.22.0"
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
|
||||
|
|
|
@ -112,10 +112,15 @@ impl Client {
|
|||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str("![](data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=");
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
inputs: "_test ".to_string().repeat(max_input_length as usize),
|
||||
inputs,
|
||||
truncate,
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct LlavaNext {
|
||||
text_config: TextConfig,
|
||||
vision_config: VisionConfig,
|
||||
image_grid_pinpoints: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
fn get_anyres_image_grid_shape(
|
||||
height: usize,
|
||||
width: usize,
|
||||
grid_pinpoints: &[(usize, usize)],
|
||||
patch_size: usize,
|
||||
) -> (usize, usize) {
|
||||
let (height, width) = select_best_resolution(height, width, grid_pinpoints);
|
||||
(height / patch_size, width / patch_size)
|
||||
}
|
||||
|
||||
/// Selects the best resolution from a list of possible resolutions based on the original size.
|
||||
/// This is done by calculating the effective and wasted resolution for each possible resolution.
|
||||
/// The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
|
||||
fn select_best_resolution(
|
||||
original_height: usize,
|
||||
original_width: usize,
|
||||
possible_resolutions: &[(usize, usize)],
|
||||
) -> (usize, usize) {
|
||||
let mut best_fit = None;
|
||||
let mut max_effective_resolution = 0;
|
||||
let mut min_wasted_resolution = f32::NEG_INFINITY;
|
||||
|
||||
for (height, width) in possible_resolutions {
|
||||
let wscale = *width as f32 / original_width as f32;
|
||||
let hscale = *height as f32 / original_height as f32;
|
||||
// f32 partial ord.
|
||||
let scale = if wscale > hscale { hscale } else { wscale };
|
||||
let downscaled_width = (*width as f32 * scale) as usize;
|
||||
let downscaled_height = (*height as f32 * scale) as usize;
|
||||
let effective_resolution = std::cmp::min(
|
||||
downscaled_width * downscaled_height,
|
||||
original_width * original_height,
|
||||
);
|
||||
let wasted_resolution = (width * height) - effective_resolution;
|
||||
|
||||
if effective_resolution > max_effective_resolution
|
||||
|| (effective_resolution == max_effective_resolution
|
||||
&& (wasted_resolution as f32) < min_wasted_resolution)
|
||||
{
|
||||
max_effective_resolution = effective_resolution;
|
||||
min_wasted_resolution = wasted_resolution as f32;
|
||||
best_fit = Some((*height, *width));
|
||||
}
|
||||
}
|
||||
|
||||
best_fit.unwrap_or((original_height, original_width))
|
||||
}
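A test-style sketch of the selection logic above, assuming `select_best_resolution` is in scope and using the standard LLaVA-NeXT pinpoints (the expected value follows from the code, not from the diff):

    #[test]
    fn best_resolution_square_image() {
        let pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)];
        // Only the (672, 672) grid keeps all 640 * 640 = 409600 original pixels after
        // downscaling, so it wins on effective resolution alone.
        assert_eq!(select_best_resolution(640, 640, &pinpoints), (672, 672));
    }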
|
||||
|
||||
impl LlavaNext {
|
||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||
let image_size = self.vision_config.image_size;
|
||||
let patch_size = self.vision_config.patch_size;
|
||||
assert!(image_size % patch_size == 0);
|
||||
let npatches = image_size / patch_size;
|
||||
let (num_patch_height, num_patch_width) =
|
||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||
// Ceil
|
||||
let height_of_patch = (height * npatches + width - 1) / width;
|
||||
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
|
||||
// They are only added after width
|
||||
let newline_features = height_of_patch * num_patch_width;
|
||||
// The base patch covers the entire image
|
||||
let base_features = npatches.pow(2);
|
||||
unpadded_features + newline_features + base_features
|
||||
}
|
||||
}
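Tracing the arithmetic above for the first case in the test below (a 640x640 image with image_size 336 and patch_size 14): npatches = 336 / 14 = 24; the best anyres resolution is (672, 672), giving a 2x2 grid of 336-pixel tiles; height_of_patch = ceil(640 * 24 / 640) = 24; so unpadded_features = 24 * 24 * 2 * 2 = 2304, newline_features = 24 * 2 = 48, base_features = 24^2 = 576, for a total of 2928 image slots.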
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct ClipVisionModel {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Config {
|
||||
LlavaNext(LlavaNext),
|
||||
ClipVisionModel(ClipVisionModel),
|
||||
Mistral,
|
||||
Idefics,
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
Santacoder,
|
||||
Bloom,
|
||||
Mpt,
|
||||
GptNeox,
|
||||
Phi,
|
||||
#[serde(rename = "phi-msft")]
|
||||
PhiMsft,
|
||||
Llama,
|
||||
Baichuan,
|
||||
Gemma,
|
||||
Cohere,
|
||||
Drbx,
|
||||
Falcon,
|
||||
Mixtral,
|
||||
Starcoder2,
|
||||
Qwen2,
|
||||
Opt,
|
||||
T5,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct TextConfig {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct VisionConfig {
|
||||
image_size: usize,
|
||||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_llava_next_features() {
|
||||
let config = LlavaNext {
|
||||
text_config: TextConfig {},
|
||||
vision_config: VisionConfig {
|
||||
image_size: 336,
|
||||
patch_size: 14,
|
||||
},
|
||||
image_grid_pinpoints: vec![
|
||||
(336, 672),
|
||||
(672, 336),
|
||||
(672, 672),
|
||||
(1008, 336),
|
||||
(336, 1008),
|
||||
],
|
||||
};
|
||||
|
||||
let slots = config.get_number_of_features(640, 640);
|
||||
assert_eq!(slots, 2928);
|
||||
let slots = config.get_number_of_features(480, 640);
|
||||
assert_eq!(slots, 2340);
|
||||
let slots = config.get_number_of_features(899, 1024);
|
||||
assert_eq!(slots, 2732);
|
||||
let slots = config.get_number_of_features(1024, 899);
|
||||
assert_eq!(slots, 3320);
|
||||
}
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
/// Batching and inference logic
|
||||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::{
|
||||
ChatTemplateInputs, Entry, GenerateRequest, GenerateStreamResponse, HubTokenizerConfig,
|
||||
Message, PrefillToken, Queue, Token,
|
||||
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
|
||||
HubTokenizerConfig, Message, PrefillToken, Queue, Token,
|
||||
};
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
|
@ -86,7 +86,18 @@ impl Infer {
|
|||
|
||||
let chat_template = tokenizer_config
|
||||
.chat_template
|
||||
.map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));
|
||||
.and_then(|t| match t {
|
||||
ChatTemplateVersions::Single(template) => Some(template),
|
||||
ChatTemplateVersions::Multiple(templates) => templates
|
||||
.into_iter()
|
||||
.find(|t| t.name == "default")
|
||||
.map(|t| t.template),
|
||||
})
|
||||
.map(|t| {
|
||||
// .strip() is not supported in minijinja
|
||||
let t = t.replace(".strip()", " | trim");
|
||||
ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
|
||||
});
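The `.strip()` rewrite above is a plain substring replacement; roughly, on a hypothetical template fragment:

    fn main() {
        let template = "{{ message['content'].strip() }}".to_string();
        let template = template.replace(".strip()", " | trim");
        assert_eq!(template, "{{ message['content'] | trim }}");
    }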
|
||||
|
||||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
|
@ -1099,7 +1110,7 @@ mod tests {
|
|||
ChatTemplateTestItem {
|
||||
name: "_base",
|
||||
chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
|
@ -1110,7 +1121,7 @@ mod tests {
|
|||
ChatTemplateTestItem {
|
||||
name: "blenderbot",
|
||||
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
|
@ -1121,7 +1132,7 @@ mod tests {
|
|||
ChatTemplateTestItem {
|
||||
name: "blenderbot_small",
|
||||
chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
|
@ -1132,7 +1143,7 @@ mod tests {
|
|||
ChatTemplateTestItem {
|
||||
name: "bloom",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
|
@ -1143,7 +1154,7 @@ mod tests {
|
|||
ChatTemplateTestItem {
|
||||
name: "gpt_neox",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
|
@ -1154,38 +1165,37 @@ mod tests {
|
|||
ChatTemplateTestItem {
|
||||
name: "gpt2",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "llama",
|
||||
// NOTE: the `.strip()` has been replaced with `| trim` in the following template
|
||||
chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
|
||||
target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "whisper",
|
||||
chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
|
||||
input: ChatTemplateInputs{
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: true,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("<|endoftext|>"),
|
||||
},
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>"
|
||||
}
|
||||
|
||||
target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
|
||||
},
|
||||
];
|
||||
|
||||
#[allow(unused_variables)] // name is unused
|
||||
|
@ -1211,7 +1221,7 @@ mod tests {
|
|||
messages: example_chat_with_system.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some(""),
|
||||
eos_token: Some("</s>")
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
|
@ -1237,7 +1247,7 @@ mod tests {
|
|||
bos_token: Some(""),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>"
|
||||
target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
|
||||
|
@ -1259,7 +1269,7 @@ mod tests {
|
|||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
|
||||
target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
|
@ -1276,7 +1286,7 @@ mod tests {
|
|||
name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
|
||||
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||
input: ChatTemplateInputs {
|
||||
messages: example_chat.clone(),
|
||||
messages: example_chat.clone(),
|
||||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
|
@ -1360,7 +1370,7 @@ mod tests {
|
|||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>"
|
||||
target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
|
||||
},
|
||||
ChatTemplateTestItem {
|
||||
name: "internlm/internlm2-chat-7b",
|
||||
|
@ -1443,7 +1453,7 @@ mod tests {
|
|||
eos_token: Some("</s>"),
|
||||
},
|
||||
target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
#[allow(unused_variables)] // name is unused
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
pub mod config;
|
||||
mod health;
|
||||
/// Text Generation Inference Webserver
|
||||
mod infer;
|
||||
|
@ -48,9 +49,22 @@ pub struct HubModelInfo {
|
|||
pub pipeline_tag: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub struct ChatTemplate {
|
||||
name: String,
|
||||
template: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatTemplateVersions {
|
||||
Single(String),
|
||||
Multiple(Vec<ChatTemplate>),
|
||||
}
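Because the enum is `#[serde(untagged)]`, both tokenizer_config.json shapes deserialize without a tag; a small self-contained sketch (the template strings are made up):

    use serde::Deserialize;

    #[allow(dead_code)]
    #[derive(Debug, Deserialize, PartialEq)]
    struct ChatTemplate {
        name: String,
        template: String,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    #[serde(untagged)]
    enum ChatTemplateVersions {
        Single(String),
        Multiple(Vec<ChatTemplate>),
    }

    fn main() -> Result<(), serde_json::Error> {
        // Older tokenizer_config.json files carry a single template string.
        let single: ChatTemplateVersions = serde_json::from_str(r#""{{ messages }}""#)?;
        assert_eq!(single, ChatTemplateVersions::Single("{{ messages }}".to_string()));
        // Newer ones carry a list of named templates; the router picks the "default" one.
        let multiple: ChatTemplateVersions =
            serde_json::from_str(r#"[{"name": "default", "template": "{{ messages }}"}]"#)?;
        assert!(matches!(multiple, ChatTemplateVersions::Multiple(_)));
        Ok(())
    }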
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct HubTokenizerConfig {
|
||||
pub chat_template: Option<String>,
|
||||
pub chat_template: Option<ChatTemplateVersions>,
|
||||
pub completion_template: Option<String>,
|
||||
#[serde(deserialize_with = "token_serde::deserialize")]
|
||||
pub bos_token: Option<String>,
|
||||
|
@ -977,7 +991,10 @@ mod tests {
|
|||
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||
|
||||
// check that we successfully parsed the tokens
|
||||
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||
assert_eq!(
|
||||
config.chat_template,
|
||||
Some(ChatTemplateVersions::Single("test".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
|
@ -1009,7 +1026,10 @@ mod tests {
|
|||
let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();
|
||||
|
||||
// check that we successfully parsed the tokens
|
||||
assert_eq!(config.chat_template, Some("test".to_string()));
|
||||
assert_eq!(
|
||||
config.chat_template,
|
||||
Some(ChatTemplateVersions::Single("test".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
config.bos_token,
|
||||
Some("<|begin▁of▁sentence|>".to_string())
|
||||
|
|
|
@ -13,6 +13,7 @@ use std::io::BufReader;
|
|||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::config::Config;
|
||||
use text_generation_router::{server, HubModelInfo, HubTokenizerConfig};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
|
@ -34,7 +35,7 @@ struct Args {
|
|||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_length: usize,
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
|
@ -89,7 +90,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
|
@ -117,13 +118,13 @@ async fn main() -> Result<(), RouterError> {
|
|||
init_logging(otlp_endpoint, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_length >= max_total_tokens {
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_length` must be < `max_total_tokens`".to_string(),
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_length as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
|
@ -191,15 +192,19 @@ async fn main() -> Result<(), RouterError> {
|
|||
};
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (tokenizer, model_info) = if local_model {
|
||||
let (tokenizer, model_info, config) = if local_model {
|
||||
let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
|
||||
let model_info = HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
};
|
||||
let config: Option<Config> = std::fs::read_to_string(local_path.join("config.json"))
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| serde_json::from_str(c).ok());
|
||||
|
||||
(tokenizer, model_info)
|
||||
(tokenizer, model_info, config)
|
||||
} else if let Some(api) = api.clone() {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
|
@ -212,6 +217,19 @@ async fn main() -> Result<(), RouterError> {
|
|||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
|
||||
let config: Option<Config> = api_repo.get("config.json").await.ok().and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
|
||||
let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
HubModelInfo {
|
||||
|
@ -221,7 +239,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
}
|
||||
});
|
||||
|
||||
(tokenizer, model_info)
|
||||
(tokenizer, model_info, config)
|
||||
} else {
|
||||
// No API and no local model
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
|
@ -229,6 +247,8 @@ async fn main() -> Result<(), RouterError> {
|
|||
));
|
||||
};
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
|
||||
// Load tokenizer config if found locally, or check if we can get it from the API if needed
|
||||
let tokenizer_config = if let Some(path) = tokenizer_config_path {
|
||||
tracing::info!("Using local tokenizer config from user specified path");
|
||||
|
@ -291,7 +311,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
tracing::info!("Warming up model");
|
||||
let max_supported_batch_total_tokens = match sharded_client
|
||||
.warmup(
|
||||
max_input_length as u32,
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
|
@ -354,7 +374,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_length,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
|
@ -363,6 +383,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
max_batch_size,
|
||||
sharded_client,
|
||||
tokenizer,
|
||||
config,
|
||||
validation_workers,
|
||||
addr,
|
||||
cors_allow_origin,
|
||||
|
@ -381,12 +402,15 @@ async fn main() -> Result<(), RouterError> {
|
|||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" on ansi supported platforms)
|
||||
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
|
||||
let mut layers = Vec::new();
|
||||
|
||||
// STDOUT/STDERR layer
|
||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_file(true)
|
||||
.with_ansi(ansi)
|
||||
.with_line_number(true);
|
||||
|
||||
let fmt_layer = match json_output {
|
||||
|
|
|
@ -190,16 +190,22 @@ impl State {
|
|||
token_budget: u32,
|
||||
) -> Option<NextBatch> {
|
||||
if self.entries.is_empty() {
|
||||
tracing::debug!("No queue");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if we have enough entries
|
||||
if let Some(min_size) = min_size {
|
||||
if self.entries.len() < min_size {
|
||||
tracing::debug!("Not enough entries");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Pad prefill_token_budget to be a multiple of block size
|
||||
let prefill_token_budget =
|
||||
((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
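// e.g. with block_size = 16, a prefill budget of 4100 tokens is rounded up to 4112.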
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||
next_batch_span.follows_from(&Span::current());
|
||||
|
@ -218,6 +224,7 @@ impl State {
|
|||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
tracing::debug!("Dropping entry");
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -254,10 +261,12 @@ impl State {
|
|||
{
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
||||
tracing::debug!("Accepting entry");
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
|
@ -288,6 +297,7 @@ impl State {
|
|||
|
||||
// Empty batch
|
||||
if batch_requests.is_empty() {
|
||||
tracing::debug!("Filtered out all entries");
|
||||
return None;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::config::Config;
|
||||
/// HTTP Server logic
|
||||
use crate::health::Health;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
|
@ -164,7 +165,8 @@ async fn generate(
|
|||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
tracing::debug!("Input: {}", req.inputs);
|
||||
// Do not log ultra long inputs, like image payloads.
|
||||
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
|
||||
|
||||
let compute_characters = req.inputs.chars().count();
|
||||
let mut add_prompt = None;
|
||||
|
@ -1154,6 +1156,7 @@ pub async fn run(
|
|||
max_batch_size: Option<usize>,
|
||||
client: ShardedClient,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
|
@ -1236,6 +1239,7 @@ pub async fn run(
|
|||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
use crate::config::Config;
|
||||
/// Payload validation logic
|
||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||
use crate::{GenerateParameters, GenerateRequest, GrammarType};
|
||||
use jsonschema::{Draft, JSONSchema};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use text_generation_client::{
|
||||
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokenizers::TruncationDirection;
|
||||
// use tokenizers::TruncationDirection;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use image::{io::Reader as ImageReader, ImageFormat};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{instrument, Span};
|
||||
|
@ -34,6 +38,7 @@ impl Validation {
|
|||
pub(crate) fn new(
|
||||
workers: usize,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
|
@ -50,12 +55,13 @@ impl Validation {
|
|||
// Create workers
|
||||
for _ in 0..workers {
|
||||
let tokenizer_clone = tokenizer.clone();
|
||||
let config_clone = config.clone();
|
||||
let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();
|
||||
senders.push(tokenizer_sender);
|
||||
|
||||
// Spawn worker
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokenizer_worker(tokenizer_clone, tokenizer_receiver)
|
||||
tokenizer_worker(tokenizer_clone, config_clone, tokenizer_receiver)
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -155,14 +161,17 @@ impl Validation {
|
|||
} else {
|
||||
return Err(ValidationError::UnsetMaxNewTokens);
|
||||
};
|
||||
let input_length = truncate.unwrap_or(self.max_input_length);
|
||||
let mut input_length = truncate.unwrap_or(self.max_input_length);
|
||||
|
||||
// We don't have a tokenizer, therefore we have no idea how long the query is, let
|
||||
// them through and hope for the best.
|
||||
// Validate MaxNewTokens
|
||||
if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
|
||||
return Err(ValidationError::MaxNewTokens(
|
||||
self.max_total_tokens - self.max_input_length,
|
||||
max_new_tokens,
|
||||
));
|
||||
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
||||
// return Err(ValidationError::MaxNewTokens(
|
||||
// self.max_total_tokens - self.max_input_length,
|
||||
// max_new_tokens,
|
||||
// ));
|
||||
}
|
||||
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
|
@ -408,48 +417,137 @@ async fn round_robin_task(
|
|||
}
|
||||
|
||||
/// Start tokenization workers
|
||||
fn tokenizer_worker(tokenizer: Tokenizer, mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
|
||||
fn tokenizer_worker(
|
||||
tokenizer: Tokenizer,
|
||||
config: Option<Config>,
|
||||
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
|
||||
) {
|
||||
// Loop over requests
|
||||
let is_multimodal = {
|
||||
let vocab = tokenizer.get_vocab(true);
|
||||
vocab.contains_key("<image>")
|
||||
};
|
||||
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
|
||||
parent_span.in_scope(|| {
|
||||
response_tx
|
||||
.send(prepare_input(inputs, truncate, &tokenizer, is_multimodal))
|
||||
.send(prepare_input(inputs, truncate, &tokenizer, &config))
|
||||
.unwrap_or(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {
|
||||
match mimetype {
|
||||
"image/png" => Some(ImageFormat::Png),
|
||||
"image/jpeg" => Some(ImageFormat::Jpeg),
|
||||
"image/jpg" => Some(ImageFormat::Jpeg),
|
||||
"image/gif" => Some(ImageFormat::Gif),
|
||||
"image/webp" => Some(ImageFormat::WebP),
|
||||
"image/tiff" => Some(ImageFormat::Tiff),
|
||||
// "image/pnm"=>Some(ImageFormat::Pnm),
|
||||
// "image/tga"=>Some(ImageFormat::Tga),
|
||||
// "image/dds"=>Some(ImageFormat::Dds),
|
||||
// "image/bmp"=>Some(ImageFormat::Bmp),
|
||||
// "image/ico"=>Some(ImageFormat::Ico),
|
||||
// "image/x-exr"=>Some(ImageFormat::OpenExr),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
fn format_to_mimetype(format: ImageFormat) -> String {
|
||||
match format {
|
||||
ImageFormat::Png => "image/png",
|
||||
ImageFormat::Jpeg => "image/jpeg",
|
||||
ImageFormat::Gif => "image/gif",
|
||||
ImageFormat::WebP => "image/webp",
|
||||
ImageFormat::Tiff => "image/tiff",
|
||||
_ => "application/octet-stream",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn fetch_image(input: &str) -> Result<(String, usize, usize), ValidationError> {
|
||||
if input.starts_with("![](http://") || input.starts_with("![](https://") {
|
||||
let url = &input["![](".len()..input.len() - 1];
|
||||
let data = reqwest::blocking::get(url)?.bytes()?;
|
||||
|
||||
let format = image::guess_format(&data)?;
|
||||
// TODO Remove this clone
|
||||
let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;
|
||||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
let mimetype = format_to_mimetype(format);
|
||||
let encoded = STANDARD.encode(data);
|
||||
let data_uri = format!("![](data:{mimetype};base64,{encoded})");
|
||||
Ok((data_uri, height, width))
|
||||
} else if input.starts_with("![](data:") {
|
||||
// Remove ![](....)
|
||||
let content = &input["![](data:".len()..input.len() - 1];
|
||||
let tokens: Vec<_> = content.split(';').collect();
|
||||
if tokens.len() != 2 {
|
||||
return Err(ValidationError::InvalidImageContent(content.to_string()));
|
||||
}
|
||||
let mimetype = tokens[0];
|
||||
let content = tokens[1];
|
||||
|
||||
if !content.starts_with("base64,") {
|
||||
return Err(ValidationError::InvalidImageContent(content.to_string()));
|
||||
}
|
||||
|
||||
let data = STANDARD.decode(content["base64,".len()..].as_bytes())?;
|
||||
let img = if let Some(format) = format_from_mimetype(mimetype) {
|
||||
ImageReader::with_format(Cursor::new(data), format).decode()?
|
||||
} else {
|
||||
ImageReader::new(Cursor::new(data))
|
||||
.with_guessed_format()
|
||||
.map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?
|
||||
.decode()?
|
||||
};
|
||||
|
||||
let height: usize = img.height().try_into()?;
|
||||
let width: usize = img.width().try_into()?;
|
||||
Ok((input.to_string(), height, width))
|
||||
} else {
|
||||
Err(ValidationError::InvalidImageContent(input.to_string()))
|
||||
}
|
||||
}
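The slicing at the top of `fetch_image` simply strips the markdown image wrapper; a tiny sketch with a placeholder URL:

    fn main() {
        let input = "![](https://example.com/cat.png)";
        // Same slicing as above: drop the leading "![](" and the trailing ")".
        let url = &input["![](".len()..input.len() - 1];
        assert_eq!(url, "https://example.com/cat.png");
    }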
|
||||
|
||||
/// Get input length and optionally truncate it
|
||||
fn prepare_input(
|
||||
mut inputs: String,
|
||||
truncate: Option<usize>,
|
||||
_truncate: Option<usize>,
|
||||
tokenizer: &Tokenizer,
|
||||
is_multimodal: bool,
|
||||
config: &Option<Config>,
|
||||
) -> Result<(tokenizers::Encoding, String), ValidationError> {
|
||||
let simplified_query = if is_multimodal {
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
RE.replace_all(&inputs, "<image>").into()
|
||||
} else {
|
||||
inputs.clone()
|
||||
};
|
||||
// Get the number of tokens in the input
|
||||
let mut encoding = tokenizer
|
||||
.encode(simplified_query, true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
// Optionally truncate
|
||||
if let Some(truncate) = truncate {
|
||||
if truncate < encoding.len() && !is_multimodal {
|
||||
encoding.truncate(truncate, 0, TruncationDirection::Left);
|
||||
inputs = tokenizer
|
||||
.decode(encoding.get_ids(), false)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let tokenizer_query = match config {
|
||||
Some(Config::LlavaNext(config)) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
modified_inputs.push_str(&image_uri);
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
}
|
||||
}
|
||||
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
|
||||
_ => inputs.clone(),
|
||||
};
|
||||
|
||||
// Get the number of tokens in the input
|
||||
let encoding = tokenizer
|
||||
.encode(tokenizer_query, true)
|
||||
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
|
||||
|
||||
Ok((encoding, inputs))
|
||||
}
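For the Idefics branch, the regex rewrite amounts to replacing each markdown image with a single `<image>` token; a standalone sketch using the same pattern (the prompt text is made up):

    use regex::Regex;

    fn main() {
        let re = Regex::new(r"!\[\]\([^\)]*\)").unwrap();
        let inputs = "What is in this picture? ![](data:image/png;base64,AAAA)";
        let tokenizer_query = re.replace_all(inputs, "<image>");
        assert_eq!(tokenizer_query, "What is in this picture? <image>");
    }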
|
||||
|
@ -523,6 +621,16 @@ pub enum ValidationError {
|
|||
Grammar,
|
||||
#[error("grammar is not valid: {0}")]
|
||||
InvalidGrammar(String),
|
||||
#[error("base64 encoding is invalid: {0}")]
|
||||
InvalidBase64(#[from] base64::DecodeError),
|
||||
#[error("invalid image: {0}")]
|
||||
InvalidImage(#[from] image::ImageError),
|
||||
#[error("invalid integer: {0}")]
|
||||
InvalidInt(#[from] core::num::TryFromIntError),
|
||||
#[error("invalid image content: {0}")]
|
||||
InvalidImageContent(String),
|
||||
#[error("Could not fetch image: {0}")]
|
||||
FailedFetchImage(#[from] reqwest::Error),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -541,9 +649,11 @@ mod tests {
|
|||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -557,8 +667,9 @@ mod tests {
|
|||
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
|
||||
.await
|
||||
{
|
||||
Err(ValidationError::MaxNewTokens(1, 10)) => (),
|
||||
_ => panic!("Unexpected not max new tokens"),
|
||||
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
|
||||
Ok((_s, 0, 10)) => (),
|
||||
r => panic!("Unexpected not max new tokens: {r:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -572,9 +683,11 @@ mod tests {
|
|||
let max_total_tokens = 6;
|
||||
let disable_grammar_support = true;
|
||||
let workers = 1;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -603,9 +716,11 @@ mod tests {
|
|||
let max_total_tokens = 6;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -639,9 +754,11 @@ mod tests {
|
|||
let max_total_tokens = 106;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequence,
|
||||
max_top_n_tokens,
|
||||
|
@ -704,9 +821,11 @@ mod tests {
|
|||
let max_total_tokens = 106;
|
||||
let workers = 1;
|
||||
let disable_grammar_support = true;
|
||||
let config = None;
|
||||
let validation = Validation::new(
|
||||
workers,
|
||||
tokenizer,
|
||||
config,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
|
|
|
@ -17,9 +17,6 @@ gen-server:
|
|||
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
|
||||
touch text_generation_server/pb/__init__.py
|
||||
|
||||
install-megablocks:
|
||||
pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
|
||||
|
||||
install: gen-server
|
||||
pip install pip --upgrade
|
||||
pip install -r requirements_cuda.txt
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
eetq_commit := 71adb5e191bb8290069a580abff0355d7b2dd5c9
|
||||
eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0
|
||||
|
||||
eetq:
|
||||
# Clone eetq
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
flash_att_v2_commit_cuda := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
|
||||
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
|
||||
flash_att_v2_commit_rocm := 8736558c287ff2ef28b24878e42828c595ac3e69
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
vllm-cuda:
|
||||
# Clone vllm
|
||||
pip install -U ninja packaging --no-cache-dir
|
||||
git clone https://github.com/vllm-project/vllm.git vllm
|
||||
git clone https://github.com/OlivierDehaene/vllm.git vllm
|
||||
|
||||
build-vllm-cuda: vllm-cuda
|
||||
cd vllm && git fetch && git checkout f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
|
||||
cd vllm && git fetch && git checkout 4bec8cee87f6bb8cebaec297029713cd2082e0b2
|
||||
cd vllm && python setup.py build
|
||||
|
||||
install-vllm-cuda: build-vllm-cuda
|
||||
|
|
|
@@ -1,14 +1,14 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.

[[package]]
name = "accelerate"
version = "0.28.0"
version = "0.29.1"
description = "Accelerate"
optional = true
python-versions = ">=3.8.0"
files = [
    {file = "accelerate-0.28.0-py3-none-any.whl", hash = "sha256:8ae25f8a8dc4cf12283842c469113836300545fb0dfa46fef331fb0a2ac8b421"},
    {file = "accelerate-0.28.0.tar.gz", hash = "sha256:32019a49f4b3a85cc179ac4e38e9e2971f1a997dee026be0512816499464c4d5"},
    {file = "accelerate-0.29.1-py3-none-any.whl", hash = "sha256:7eda0c8bc62bc59129103310f1272a0fb7b3ebc55fc8920cfe1c102db30aca58"},
    {file = "accelerate-0.29.1.tar.gz", hash = "sha256:d1d0e5a591177891812fd6d1bc843af191e1192c80e5180258f52fefcb653a9f"},
]

[package.dependencies]
@@ -21,14 +21,14 @@ safetensors = ">=0.3.1"
torch = ">=1.10.0"

[package.extras]
dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed (<0.13.0)", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
dev = ["bitsandbytes", "black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.2.1,<0.3.0)", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.2.1,<0.3.0)"]
rich = ["rich"]
sagemaker = ["sagemaker"]
test-dev = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
test-dev = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist"]
test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"]
testing = ["bitsandbytes", "datasets", "deepspeed (<0.13.0)", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]

[[package]]
name = "aiohttp"

@@ -471,18 +471,18 @@ test = ["pytest (>=6)"]

[[package]]
name = "filelock"
version = "3.13.1"
version = "3.13.3"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
files = [
    {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"},
    {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"},
    {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"},
    {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"},
]

[package.extras]
docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"]
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
typing = ["typing-extensions (>=4.8)"]

[[package]]

@@ -1512,14 +1512,13 @@ files = [

[[package]]
name = "nvidia-nvjitlink-cu12"
version = "12.4.99"
version = "12.4.127"
description = "Nvidia JIT LTO Library"
optional = true
python-versions = ">=3"
files = [
    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"},
    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
    {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
]

[[package]]

@@ -1806,13 +1805,13 @@ xml = ["lxml (>=4.9.2)"]

[[package]]
name = "peft"
version = "0.9.0"
version = "0.10.0"
description = "Parameter-Efficient Fine-Tuning (PEFT)"
optional = true
python-versions = ">=3.8.0"
files = [
    {file = "peft-0.9.0-py3-none-any.whl", hash = "sha256:d14223fee6050c53593733e8f763d94c13577e1220987f59ae473d988f2ccd91"},
    {file = "peft-0.9.0.tar.gz", hash = "sha256:3b8d09dff94d1bfa72e064cb26af5952fd82428e2bcce432cfaf091f5035b04b"},
    {file = "peft-0.10.0-py3-none-any.whl", hash = "sha256:d5249c97e818d3e31f92553c73c2953acd0ec12649b8b749afff7152cbc86cbb"},
    {file = "peft-0.10.0.tar.gz", hash = "sha256:36a7628c15f88d37abb26cfc74c22468f9037ee02e9c9b65de943cfe7c672049"},
]

[package.dependencies]

@@ -1835,79 +1834,80 @@ test = ["black", "datasets", "diffusers (<0.21.0)", "hf-doc-builder", "parameter

[[package]]
name = "pillow"
version = "10.2.0"
version = "10.3.0"
description = "Python Imaging Library (Fork)"
optional = false
python-versions = ">=3.8"
files = [
{file = "pillow-10.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:7823bdd049099efa16e4246bdf15e5a13dbb18a51b68fa06d6c1d4d8b99a796e"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:83b2021f2ade7d1ed556bc50a399127d7fb245e725aa0113ebd05cfe88aaf588"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fad5ff2f13d69b7e74ce5b4ecd12cc0ec530fcee76356cac6742785ff71c452"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da2b52b37dad6d9ec64e653637a096905b258d2fc2b984c41ae7d08b938a67e4"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:47c0995fc4e7f79b5cfcab1fc437ff2890b770440f7696a3ba065ee0fd496563"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:322bdf3c9b556e9ffb18f93462e5f749d3444ce081290352c6070d014c93feb2"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:51f1a1bffc50e2e9492e87d8e09a17c5eea8409cda8d3f277eb6edc82813c17c"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:69ffdd6120a4737710a9eee73e1d2e37db89b620f702754b8f6e62594471dee0"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-win32.whl", hash = "sha256:c6dafac9e0f2b3c78df97e79af707cdc5ef8e88208d686a4847bab8266870023"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:aebb6044806f2e16ecc07b2a2637ee1ef67a11840a66752751714a0d924adf72"},
|
||||
{file = "pillow-10.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:7049e301399273a0136ff39b84c3678e314f2158f50f517bc50285fb5ec847ad"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:35bb52c37f256f662abdfa49d2dfa6ce5d93281d323a9af377a120e89a9eafb5"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9c23f307202661071d94b5e384e1e1dc7dfb972a28a2310e4ee16103e66ddb67"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:773efe0603db30c281521a7c0214cad7836c03b8ccff897beae9b47c0b657d61"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11fa2e5984b949b0dd6d7a94d967743d87c577ff0b83392f17cb3990d0d2fd6e"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:716d30ed977be8b37d3ef185fecb9e5a1d62d110dfbdcd1e2a122ab46fddb03f"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a086c2af425c5f62a65e12fbf385f7c9fcb8f107d0849dba5839461a129cf311"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c8de2789052ed501dd829e9cae8d3dcce7acb4777ea4a479c14521c942d395b1"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:609448742444d9290fd687940ac0b57fb35e6fd92bdb65386e08e99af60bf757"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-win32.whl", hash = "sha256:823ef7a27cf86df6597fa0671066c1b596f69eba53efa3d1e1cb8b30f3533068"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:1da3b2703afd040cf65ec97efea81cfba59cdbed9c11d8efc5ab09df9509fc56"},
|
||||
{file = "pillow-10.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:edca80cbfb2b68d7b56930b84a0e45ae1694aeba0541f798e908a49d66b837f1"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:1b5e1b74d1bd1b78bc3477528919414874748dd363e6272efd5abf7654e68bef"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0eae2073305f451d8ecacb5474997c08569fb4eb4ac231ffa4ad7d342fdc25ac"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7c2286c23cd350b80d2fc9d424fc797575fb16f854b831d16fd47ceec078f2c"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e23412b5c41e58cec602f1135c57dfcf15482013ce6e5f093a86db69646a5aa"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:52a50aa3fb3acb9cf7213573ef55d31d6eca37f5709c69e6858fe3bc04a5c2a2"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:127cee571038f252a552760076407f9cff79761c3d436a12af6000cd182a9d04"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8d12251f02d69d8310b046e82572ed486685c38f02176bd08baf216746eb947f"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:54f1852cd531aa981bc0965b7d609f5f6cc8ce8c41b1139f6ed6b3c54ab82bfb"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-win32.whl", hash = "sha256:257d8788df5ca62c980314053197f4d46eefedf4e6175bc9412f14412ec4ea2f"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:154e939c5f0053a383de4fd3d3da48d9427a7e985f58af8e94d0b3c9fcfcf4f9"},
|
||||
{file = "pillow-10.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:f379abd2f1e3dddb2b61bc67977a6b5a0a3f7485538bcc6f39ec76163891ee48"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8373c6c251f7ef8bda6675dd6d2b3a0fcc31edf1201266b5cf608b62a37407f9"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:870ea1ada0899fd0b79643990809323b389d4d1d46c192f97342eeb6ee0b8483"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4b6b1e20608493548b1f32bce8cca185bf0480983890403d3b8753e44077129"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3031709084b6e7852d00479fd1d310b07d0ba82765f973b543c8af5061cf990e"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:3ff074fc97dd4e80543a3e91f69d58889baf2002b6be64347ea8cf5533188213"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:cb4c38abeef13c61d6916f264d4845fab99d7b711be96c326b84df9e3e0ff62d"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b1b3020d90c2d8e1dae29cf3ce54f8094f7938460fb5ce8bc5c01450b01fbaf6"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:170aeb00224ab3dc54230c797f8404507240dd868cf52066f66a41b33169bdbe"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-win32.whl", hash = "sha256:c4225f5220f46b2fde568c74fca27ae9771536c2e29d7c04f4fb62c83275ac4e"},
|
||||
{file = "pillow-10.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:0689b5a8c5288bc0504d9fcee48f61a6a586b9b98514d7d29b840143d6734f39"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:b792a349405fbc0163190fde0dc7b3fef3c9268292586cf5645598b48e63dc67"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c570f24be1e468e3f0ce7ef56a89a60f0e05b30a3669a459e419c6eac2c35364"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8ecd059fdaf60c1963c58ceb8997b32e9dc1b911f5da5307aab614f1ce5c2fb"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c365fd1703040de1ec284b176d6af5abe21b427cb3a5ff68e0759e1e313a5e7e"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:70c61d4c475835a19b3a5aa42492409878bbca7438554a1f89d20d58a7c75c01"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b6f491cdf80ae540738859d9766783e3b3c8e5bd37f5dfa0b76abdecc5081f13"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d189550615b4948f45252d7f005e53c2040cea1af5b60d6f79491a6e147eef7"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:49d9ba1ed0ef3e061088cd1e7538a0759aab559e2e0a80a36f9fd9d8c0c21591"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-win32.whl", hash = "sha256:babf5acfede515f176833ed6028754cbcd0d206f7f614ea3447d67c33be12516"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:0304004f8067386b477d20a518b50f3fa658a28d44e4116970abfcd94fac34a8"},
|
||||
{file = "pillow-10.2.0-cp39-cp39-win_arm64.whl", hash = "sha256:0fb3e7fc88a14eacd303e90481ad983fd5b69c761e9e6ef94c983f91025da869"},
|
||||
{file = "pillow-10.2.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:322209c642aabdd6207517e9739c704dc9f9db943015535783239022002f054a"},
|
||||
{file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eedd52442c0a5ff4f887fab0c1c0bb164d8635b32c894bc1faf4c618dd89df2"},
|
||||
{file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb28c753fd5eb3dd859b4ee95de66cc62af91bcff5db5f2571d32a520baf1f04"},
|
||||
{file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:33870dc4653c5017bf4c8873e5488d8f8d5f8935e2f1fb9a2208c47cdd66efd2"},
|
||||
{file = "pillow-10.2.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3c31822339516fb3c82d03f30e22b1d038da87ef27b6a78c9549888f8ceda39a"},
|
||||
{file = "pillow-10.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a2b56ba36e05f973d450582fb015594aaa78834fefe8dfb8fcd79b93e64ba4c6"},
|
||||
{file = "pillow-10.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d8e6aeb9201e655354b3ad049cb77d19813ad4ece0df1249d3c793de3774f8c7"},
|
||||
{file = "pillow-10.2.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:2247178effb34a77c11c0e8ac355c7a741ceca0a732b27bf11e747bbc950722f"},
|
||||
{file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15587643b9e5eb26c48e49a7b33659790d28f190fc514a322d55da2fb5c2950e"},
|
||||
{file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753cd8f2086b2b80180d9b3010dd4ed147efc167c90d3bf593fe2af21265e5a5"},
|
||||
{file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:7c8f97e8e7a9009bcacbe3766a36175056c12f9a44e6e6f2d5caad06dcfbf03b"},
|
||||
{file = "pillow-10.2.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d1b35bcd6c5543b9cb547dee3150c93008f8dd0f1fef78fc0cd2b141c5baf58a"},
|
||||
{file = "pillow-10.2.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe4c15f6c9285dc54ce6553a3ce908ed37c8f3825b5a51a15c91442bb955b868"},
|
||||
{file = "pillow-10.2.0.tar.gz", hash = "sha256:e87f0b2c78157e12d7686b27d63c070fd65d994e8ddae6f328e0dcf4a0cd007e"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:90b9e29824800e90c84e4022dd5cc16eb2d9605ee13f05d47641eb183cd73d45"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2c405445c79c3f5a124573a051062300936b0281fee57637e706453e452746c"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78618cdbccaa74d3f88d0ad6cb8ac3007f1a6fa5c6f19af64b55ca170bfa1edf"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261ddb7ca91fcf71757979534fb4c128448b5b4c55cb6152d280312062f69599"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ce49c67f4ea0609933d01c0731b34b8695a7a748d6c8d186f95e7d085d2fe475"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b14f16f94cbc61215115b9b1236f9c18403c15dd3c52cf629072afa9d54c1cbf"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d33891be6df59d93df4d846640f0e46f1a807339f09e79a8040bc887bdcd7ed3"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b50811d664d392f02f7761621303eba9d1b056fb1868c8cdf4231279645c25f5"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-win32.whl", hash = "sha256:ca2870d5d10d8726a27396d3ca4cf7976cec0f3cb706debe88e3a5bd4610f7d2"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:f0d0591a0aeaefdaf9a5e545e7485f89910c977087e7de2b6c388aec32011e9f"},
|
||||
{file = "pillow-10.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:ccce24b7ad89adb5a1e34a6ba96ac2530046763912806ad4c247356a8f33a67b"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:5f77cf66e96ae734717d341c145c5949c63180842a545c47a0ce7ae52ca83795"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4b878386c4bf293578b48fc570b84ecfe477d3b77ba39a6e87150af77f40c57"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdcbb4068117dfd9ce0138d068ac512843c52295ed996ae6dd1faf537b6dbc27"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9797a6c8fe16f25749b371c02e2ade0efb51155e767a971c61734b1bf6293994"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:9e91179a242bbc99be65e139e30690e081fe6cb91a8e77faf4c409653de39451"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:1b87bd9d81d179bd8ab871603bd80d8645729939f90b71e62914e816a76fc6bd"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:81d09caa7b27ef4e61cb7d8fbf1714f5aec1c6b6c5270ee53504981e6e9121ad"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:048ad577748b9fa4a99a0548c64f2cb8d672d5bf2e643a739ac8faff1164238c"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-win32.whl", hash = "sha256:7161ec49ef0800947dc5570f86568a7bb36fa97dd09e9827dc02b718c5643f09"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:8eb0908e954d093b02a543dc963984d6e99ad2b5e36503d8a0aaf040505f747d"},
|
||||
{file = "pillow-10.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:4e6f7d1c414191c1199f8996d3f2282b9ebea0945693fb67392c75a3a320941f"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:e46f38133e5a060d46bd630faa4d9fa0202377495df1f068a8299fd78c84de84"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50b8eae8f7334ec826d6eeffaeeb00e36b5e24aa0b9df322c247539714c6df19"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d3bea1c75f8c53ee4d505c3e67d8c158ad4df0d83170605b50b64025917f338"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19aeb96d43902f0a783946a0a87dbdad5c84c936025b8419da0a0cd7724356b1"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74d28c17412d9caa1066f7a31df8403ec23d5268ba46cd0ad2c50fb82ae40462"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ff61bfd9253c3915e6d41c651d5f962da23eda633cf02262990094a18a55371a"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d886f5d353333b4771d21267c7ecc75b710f1a73d72d03ca06df49b09015a9ef"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4b5ec25d8b17217d635f8935dbc1b9aa5907962fae29dff220f2659487891cd3"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-win32.whl", hash = "sha256:51243f1ed5161b9945011a7360e997729776f6e5d7005ba0c6879267d4c5139d"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:412444afb8c4c7a6cc11a47dade32982439925537e483be7c0ae0cf96c4f6a0b"},
|
||||
{file = "pillow-10.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:798232c92e7665fe82ac085f9d8e8ca98826f8e27859d9a96b41d519ecd2e49a"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:4eaa22f0d22b1a7e93ff0a596d57fdede2e550aecffb5a1ef1106aaece48e96b"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cd5e14fbf22a87321b24c88669aad3a51ec052eb145315b3da3b7e3cc105b9a2"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1530e8f3a4b965eb6a7785cf17a426c779333eb62c9a7d1bbcf3ffd5bf77a4aa"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d512aafa1d32efa014fa041d38868fda85028e3f930a96f85d49c7d8ddc0383"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:339894035d0ede518b16073bdc2feef4c991ee991a29774b33e515f1d308e08d"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:aa7e402ce11f0885305bfb6afb3434b3cd8f53b563ac065452d9d5654c7b86fd"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:0ea2a783a2bdf2a561808fe4a7a12e9aa3799b701ba305de596bc48b8bdfce9d"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c78e1b00a87ce43bb37642c0812315b411e856a905d58d597750eb79802aaaa3"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-win32.whl", hash = "sha256:72d622d262e463dfb7595202d229f5f3ab4b852289a1cd09650362db23b9eb0b"},
|
||||
{file = "pillow-10.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:2034f6759a722da3a3dbd91a81148cf884e91d1b747992ca288ab88c1de15999"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2ed854e716a89b1afcedea551cd85f2eb2a807613752ab997b9974aaa0d56936"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc1a390a82755a8c26c9964d457d4c9cbec5405896cba94cf51f36ea0d855002"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4203efca580f0dd6f882ca211f923168548f7ba334c189e9eab1178ab840bf60"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3102045a10945173d38336f6e71a8dc71bcaeed55c3123ad4af82c52807b9375"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6fb1b30043271ec92dc65f6d9f0b7a830c210b8a96423074b15c7bc999975f57"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:1dfc94946bc60ea375cc39cff0b8da6c7e5f8fcdc1d946beb8da5c216156ddd8"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b09b86b27a064c9624d0a6c54da01c1beaf5b6cadfa609cf63789b1d08a797b9"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d3b2348a78bc939b4fed6552abfd2e7988e0f81443ef3911a4b8498ca084f6eb"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-win32.whl", hash = "sha256:45ebc7b45406febf07fef35d856f0293a92e7417ae7933207e90bf9090b70572"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:0ba26351b137ca4e0db0342d5d00d2e355eb29372c05afd544ebf47c0956ffeb"},
|
||||
{file = "pillow-10.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:50fd3f6b26e3441ae07b7c979309638b72abc1a25da31a81a7fbd9495713ef4f"},
|
||||
{file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_10_10_x86_64.whl", hash = "sha256:6b02471b72526ab8a18c39cb7967b72d194ec53c1fd0a70b050565a0f366d355"},
|
||||
{file = "pillow-10.3.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8ab74c06ffdab957d7670c2a5a6e1a70181cd10b727cd788c4dd9005b6a8acd9"},
|
||||
{file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:048eeade4c33fdf7e08da40ef402e748df113fd0b4584e32c4af74fe78baaeb2"},
|
||||
{file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2ec1e921fd07c7cda7962bad283acc2f2a9ccc1b971ee4b216b75fad6f0463"},
|
||||
{file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c8e73e99da7db1b4cad7f8d682cf6abad7844da39834c288fbfa394a47bbced"},
|
||||
{file = "pillow-10.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:16563993329b79513f59142a6b02055e10514c1a8e86dca8b48a893e33cf91e3"},
|
||||
{file = "pillow-10.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dd78700f5788ae180b5ee8902c6aea5a5726bac7c364b202b4b3e3ba2d293170"},
|
||||
{file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:aff76a55a8aa8364d25400a210a65ff59d0168e0b4285ba6bf2bd83cf675ba32"},
|
||||
{file = "pillow-10.3.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:b7bc2176354defba3edc2b9a777744462da2f8e921fbaf61e52acb95bafa9828"},
|
||||
{file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:793b4e24db2e8742ca6423d3fde8396db336698c55cd34b660663ee9e45ed37f"},
|
||||
{file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93480005693d247f8346bc8ee28c72a2191bdf1f6b5db469c096c0c867ac015"},
|
||||
{file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c83341b89884e2b2e55886e8fbbf37c3fa5efd6c8907124aeb72f285ae5696e5"},
|
||||
{file = "pillow-10.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1a1d1915db1a4fdb2754b9de292642a39a7fb28f1736699527bb649484fb966a"},
|
||||
{file = "pillow-10.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a0eaa93d054751ee9964afa21c06247779b90440ca41d184aeb5d410f20ff591"},
|
||||
{file = "pillow-10.3.0.tar.gz", hash = "sha256:9d2455fbf44c914840c793e89aa82d0e1763a14253a000743719ae5946814b2d"},
]

[package.extras]

@@ -2222,6 +2222,7 @@ files = [
    {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
    {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
    {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
    {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
    {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},

@@ -2636,45 +2637,45 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"]

[[package]]
name = "scipy"
version = "1.12.0"
version = "1.13.0"
description = "Fundamental algorithms for scientific computing in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78e4402e140879387187f7f25d91cc592b3501a2e51dfb320f48dfb73565f10b"},
|
||||
{file = "scipy-1.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5f00ebaf8de24d14b8449981a2842d404152774c1a1d880c901bf454cb8e2a1"},
|
||||
{file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e53958531a7c695ff66c2e7bb7b79560ffdc562e2051644c5576c39ff8efb563"},
|
||||
{file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e32847e08da8d895ce09d108a494d9eb78974cf6de23063f93306a3e419960c"},
|
||||
{file = "scipy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4c1020cad92772bf44b8e4cdabc1df5d87376cb219742549ef69fc9fd86282dd"},
|
||||
{file = "scipy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:75ea2a144096b5e39402e2ff53a36fecfd3b960d786b7efd3c180e29c39e53f2"},
|
||||
{file = "scipy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:408c68423f9de16cb9e602528be4ce0d6312b05001f3de61fe9ec8b1263cad08"},
|
||||
{file = "scipy-1.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5adfad5dbf0163397beb4aca679187d24aec085343755fcdbdeb32b3679f254c"},
|
||||
{file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3003652496f6e7c387b1cf63f4bb720951cfa18907e998ea551e6de51a04467"},
|
||||
{file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b8066bce124ee5531d12a74b617d9ac0ea59245246410e19bca549656d9a40a"},
|
||||
{file = "scipy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8bee4993817e204d761dba10dbab0774ba5a8612e57e81319ea04d84945375ba"},
|
||||
{file = "scipy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:a24024d45ce9a675c1fb8494e8e5244efea1c7a09c60beb1eeb80373d0fecc70"},
|
||||
{file = "scipy-1.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e7e76cc48638228212c747ada851ef355c2bb5e7f939e10952bc504c11f4e372"},
|
||||
{file = "scipy-1.12.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f7ce148dffcd64ade37b2df9315541f9adad6efcaa86866ee7dd5db0c8f041c3"},
|
||||
{file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c39f92041f490422924dfdb782527a4abddf4707616e07b021de33467f917bc"},
|
||||
{file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7ebda398f86e56178c2fa94cad15bf457a218a54a35c2a7b4490b9f9cb2676c"},
|
||||
{file = "scipy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:95e5c750d55cf518c398a8240571b0e0782c2d5a703250872f36eaf737751338"},
|
||||
{file = "scipy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:e646d8571804a304e1da01040d21577685ce8e2db08ac58e543eaca063453e1c"},
|
||||
{file = "scipy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:913d6e7956c3a671de3b05ccb66b11bc293f56bfdef040583a7221d9e22a2e35"},
|
||||
{file = "scipy-1.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba1b0c7256ad75401c73e4b3cf09d1f176e9bd4248f0d3112170fb2ec4db067"},
|
||||
{file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:730badef9b827b368f351eacae2e82da414e13cf8bd5051b4bdfd720271a5371"},
|
||||
{file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6546dc2c11a9df6926afcbdd8a3edec28566e4e785b915e849348c6dd9f3f490"},
|
||||
{file = "scipy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:196ebad3a4882081f62a5bf4aeb7326aa34b110e533aab23e4374fcccb0890dc"},
|
||||
{file = "scipy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:b360f1b6b2f742781299514e99ff560d1fe9bd1bff2712894b52abe528d1fd1e"},
|
||||
{file = "scipy-1.12.0.tar.gz", hash = "sha256:4bf5abab8a36d20193c698b0f1fc282c1d083c94723902c447e5d2f1780936a3"},
|
||||
{file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"},
|
||||
{file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"},
|
||||
{file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"},
|
||||
{file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"},
|
||||
{file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"},
|
||||
{file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"},
|
||||
{file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"},
|
||||
{file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"},
|
||||
{file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"},
|
||||
{file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"},
|
||||
{file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"},
|
||||
{file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"},
|
||||
{file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"},
|
||||
{file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"},
|
||||
{file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"},
|
||||
{file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"},
|
||||
{file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"},
|
||||
{file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"},
|
||||
{file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"},
|
||||
{file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"},
|
||||
{file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"},
|
||||
{file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"},
|
||||
{file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"},
|
||||
{file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"},
|
||||
{file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"},
]

[package.dependencies]
numpy = ">=1.22.4,<1.29.0"
numpy = ">=1.22.4,<2.3"

[package.extras]
dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]

[[package]]
name = "sentencepiece"

@@ -2922,36 +2923,36 @@ files = [

[[package]]
name = "torch"
version = "2.2.1"
version = "2.2.2"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = true
python-versions = ">=3.8.0"
files = [
{file = "torch-2.2.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:8d3bad336dd2c93c6bcb3268e8e9876185bda50ebde325ef211fb565c7d15273"},
|
||||
{file = "torch-2.2.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:5297f13370fdaca05959134b26a06a7f232ae254bf2e11a50eddec62525c9006"},
|
||||
{file = "torch-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:5f5dee8433798888ca1415055f5e3faf28a3bad660e4c29e1014acd3275ab11a"},
|
||||
{file = "torch-2.2.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b6d78338acabf1fb2e88bf4559d837d30230cf9c3e4337261f4d83200df1fcbe"},
|
||||
{file = "torch-2.2.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:6ab3ea2e29d1aac962e905142bbe50943758f55292f1b4fdfb6f4792aae3323e"},
|
||||
{file = "torch-2.2.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:d86664ec85902967d902e78272e97d1aff1d331f7619d398d3ffab1c9b8e9157"},
|
||||
{file = "torch-2.2.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d6227060f268894f92c61af0a44c0d8212e19cb98d05c20141c73312d923bc0a"},
|
||||
{file = "torch-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:77e990af75fb1675490deb374d36e726f84732cd5677d16f19124934b2409ce9"},
|
||||
{file = "torch-2.2.1-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:46085e328d9b738c261f470231e987930f4cc9472d9ffb7087c7a1343826ac51"},
|
||||
{file = "torch-2.2.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:2d9e7e5ecbb002257cf98fae13003abbd620196c35f85c9e34c2adfb961321ec"},
|
||||
{file = "torch-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ada53aebede1c89570e56861b08d12ba4518a1f8b82d467c32665ec4d1f4b3c8"},
|
||||
{file = "torch-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:be21d4c41ecebed9e99430dac87de1439a8c7882faf23bba7fea3fea7b906ac1"},
|
||||
{file = "torch-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:79848f46196750367dcdf1d2132b722180b9d889571e14d579ae82d2f50596c5"},
|
||||
{file = "torch-2.2.1-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:7ee804847be6be0032fbd2d1e6742fea2814c92bebccb177f0d3b8e92b2d2b18"},
|
||||
{file = "torch-2.2.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:84b2fb322ab091039fdfe74e17442ff046b258eb5e513a28093152c5b07325a7"},
|
||||
{file = "torch-2.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5c0c83aa7d94569997f1f474595e808072d80b04d34912ce6f1a0e1c24b0c12a"},
|
||||
{file = "torch-2.2.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:91a1b598055ba06b2c386415d2e7f6ac818545e94c5def597a74754940188513"},
|
||||
{file = "torch-2.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:8f93ddf3001ecec16568390b507652644a3a103baa72de3ad3b9c530e3277098"},
|
||||
{file = "torch-2.2.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:0e8bdd4c77ac2584f33ee14c6cd3b12767b4da508ec4eed109520be7212d1069"},
|
||||
{file = "torch-2.2.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:6a21bcd7076677c97ca7db7506d683e4e9db137e8420eb4a68fb67c3668232a7"},
|
||||
{file = "torch-2.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:f1b90ac61f862634039265cd0f746cc9879feee03ff962c803486301b778714b"},
|
||||
{file = "torch-2.2.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ed9e29eb94cd493b36bca9cb0b1fd7f06a0688215ad1e4b3ab4931726e0ec092"},
|
||||
{file = "torch-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:c47bc25744c743f3835831a20efdcfd60aeb7c3f9804a213f61e45803d16c2a5"},
|
||||
{file = "torch-2.2.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:0952549bcb43448c8d860d5e3e947dd18cbab491b14638e21750cb3090d5ad3e"},
|
||||
{file = "torch-2.2.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:26bd2272ec46fc62dcf7d24b2fb284d44fcb7be9d529ebf336b9860350d674ed"},
|
||||
{file = "torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585"},
|
||||
{file = "torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030"},
|
||||
{file = "torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5"},
|
||||
{file = "torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e"},
|
||||
{file = "torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2"},
|
||||
{file = "torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ad4c03b786e074f46606f4151c0a1e3740268bcf29fbd2fdf6666d66341c1dcb"},
|
||||
{file = "torch-2.2.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:32827fa1fbe5da8851686256b4cd94cc7b11be962862c2293811c94eea9457bf"},
|
||||
{file = "torch-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:f9ef0a648310435511e76905f9b89612e45ef2c8b023bee294f5e6f7e73a3e7c"},
|
||||
{file = "torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:95b9b44f3bcebd8b6cd8d37ec802048c872d9c567ba52c894bba90863a439059"},
|
||||
{file = "torch-2.2.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:49aa4126ede714c5aeef7ae92969b4b0bbe67f19665106463c39f22e0a1860d1"},
|
||||
{file = "torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:cf12cdb66c9c940227ad647bc9cf5dba7e8640772ae10dfe7569a0c1e2a28aca"},
|
||||
{file = "torch-2.2.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:89ddac2a8c1fb6569b90890955de0c34e1724f87431cacff4c1979b5f769203c"},
|
||||
{file = "torch-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:451331406b760f4b1ab298ddd536486ab3cfb1312614cfe0532133535be60bea"},
|
||||
{file = "torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:eb4d6e9d3663e26cd27dc3ad266b34445a16b54908e74725adb241aa56987533"},
|
||||
{file = "torch-2.2.2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:bf9558da7d2bf7463390b3b2a61a6a3dbb0b45b161ee1dd5ec640bf579d479fc"},
|
||||
{file = "torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd2bf7697c9e95fb5d97cc1d525486d8cf11a084c6af1345c2c2c22a6b0029d0"},
|
||||
{file = "torch-2.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b421448d194496e1114d87a8b8d6506bce949544e513742b097e2ab8f7efef32"},
|
||||
{file = "torch-2.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:3dbcd563a9b792161640c0cffe17e3270d85e8f4243b1f1ed19cca43d28d235b"},
|
||||
{file = "torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:31f4310210e7dda49f1fb52b0ec9e59382cfcb938693f6d5378f25b43d7c1d29"},
|
||||
{file = "torch-2.2.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c795feb7e8ce2e0ef63f75f8e1ab52e7fd5e1a4d7d0c31367ade1e3de35c9e95"},
|
||||
{file = "torch-2.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a6e5770d68158d07456bfcb5318b173886f579fdfbf747543901ce718ea94782"},
|
||||
{file = "torch-2.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:67dcd726edff108e2cd6c51ff0e416fd260c869904de95750e80051358680d24"},
|
||||
{file = "torch-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:539d5ef6c4ce15bd3bd47a7b4a6e7c10d49d4d21c0baaa87c7d2ef8698632dfb"},
|
||||
{file = "torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:dff696de90d6f6d1e8200e9892861fd4677306d0ef604cb18f2134186f719f82"},
|
||||
{file = "torch-2.2.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:3a4dd910663fd7a124c056c878a52c2b0be4a5a424188058fe97109d4436ee42"},
]

[package.dependencies]

@@ -3000,13 +3001,13 @@ telegram = ["requests"]

[[package]]
name = "transformers"
version = "4.39.0"
version = "4.39.3"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
python-versions = ">=3.8.0"
files = [
    {file = "transformers-4.39.0-py3-none-any.whl", hash = "sha256:7801785b1f016d667467e8c372c1c3653c18fe32ba97952059e3bea79ba22b08"},
    {file = "transformers-4.39.0.tar.gz", hash = "sha256:517a13cd633b10bea01c92ab0b3059762872c7c29da3d223db9d28e926fe330d"},
    {file = "transformers-4.39.3-py3-none-any.whl", hash = "sha256:7838034a12cca3168247f9d2d1dba6724c9de3ae0f73a108258c6b8fc5912601"},
    {file = "transformers-4.39.3.tar.gz", hash = "sha256:2586e5ff4150f122716fc40f5530e92871befc051848fbe82600969c535b762d"},
]

[package.dependencies]

@@ -3111,13 +3112,13 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=5.2,<6.0)", "isort (>=5.0.6,<6.

[[package]]
name = "typing-extensions"
version = "4.10.0"
version = "4.11.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
    {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"},
    {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
    {file = "typing_extensions-4.11.0-py3-none-any.whl", hash = "sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a"},
    {file = "typing_extensions-4.11.0.tar.gz", hash = "sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0"},
]

[[package]]

@@ -3472,4 +3473,4 @@ torch = ["torch"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13"
content-hash = "b19c9b3f63778f3fc862907c9377d9389b8d0ea8aa02f69bdba628f4829bc62d"
content-hash = "2bbb3d5acafe4bd1a616106ad3956058078a1f54c9a906a0065ee474f5af407d"

@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "1.4.5"
version = "2.0.0"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = { version = "^0.28.0", optional = true }
accelerate = { version = "^0.29.1", optional = true }
bitsandbytes = { version = "^0.43.0", optional = true }
safetensors = "^0.4"
loguru = "^0.6.0"

@@ -26,11 +26,11 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.15.0"
huggingface-hub = "^0.19.3"
transformers = "^4.38"
transformers = "^4.39"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
peft = { version = "^0.9", optional = true }
peft = { version = "^0.10", optional = true }
torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"

@@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"

@@ -27,20 +27,20 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

@@ -5,7 +5,7 @@ click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"

@@ -27,20 +27,20 @@ opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.39.3 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

@@ -19,6 +19,7 @@ class Quantization(str, Enum):
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    fp8 = "fp8"


class Dtype(str, Enum):
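The hunk above adds `fp8` to the set of values the server's `--quantize` option accepts. A minimal sketch of how a raw CLI string maps onto this kind of `str`-backed enum (the surrounding Typer wiring is assumed and not shown here):

```python
from enum import Enum


class Quantization(str, Enum):
    # Mirrors the members shown in the hunk above; earlier members elided.
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    fp8 = "fp8"


# Because the enum derives from `str`, a raw CLI string converts directly:
quantize = Quantization("fp8")
assert quantize is Quantization.fp8 and quantize.value == "fp8"
```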
@@ -249,6 +250,13 @@ def download_weights(
        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)

        if auto_convert:
            if not trust_remote_code:
                logger.warning(
                    f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
                    f"Pickle files are unsafe and can essentially contain remote code execution!"
                    f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
                )

            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Converting PyTorch weights to safetensors."

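With the change above, automatic pickle-to-safetensors conversion now warns loudly unless `--trust-remote-code` is passed. A hedged sketch of the offline workaround this implies: convert the checkpoint once in a trusted environment and serve the resulting safetensors file (`pytorch_model.bin` and `model.safetensors` are assumed filenames):

```python
import torch
from safetensors.torch import save_file

# Load the pickle checkpoint once, in an environment you trust.
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)

# Persist it as safetensors so the server never has to unpickle anything.
# Shared/duplicated tensors may need de-duplication first; this sketch ignores that.
save_file({k: v.contiguous() for k, v in state_dict.items()}, "model.safetensors")
```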
@@ -23,6 +23,10 @@ class ExceptionInterceptor(AsyncServerInterceptor):
        method_name = method_name.split("/")[-1]
        logger.exception(f"Method {method_name} encountered an error.")

        # Runtime Error cannot be recovered from
        if isinstance(err, RuntimeError):
            exit(1)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

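The interceptor change treats a `RuntimeError` raised by a shard as unrecoverable and terminates the process instead of leaving it wedged. A stripped-down illustration of that policy, deliberately not tied to the actual gRPC interceptor API the server uses:

```python
import sys
import logging

logger = logging.getLogger("shard")


async def guarded(handler, *args, **kwargs):
    """Run a request handler; unrecoverable RuntimeErrors terminate the shard."""
    try:
        return await handler(*args, **kwargs)
    except RuntimeError:
        logger.exception("Unrecoverable error, shutting the shard down")
        sys.exit(1)
    except Exception:
        logger.exception("Request failed")
        raise
```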
@ -67,6 +67,7 @@ try:
|
|||
FlashSantacoderSharded,
|
||||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.llava_next import LlavaNext
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
|
@ -144,7 +145,7 @@ def get_model(
|
|||
if speculate is not None:
|
||||
if speculate > speculate_medusa:
|
||||
raise RuntimeError(
|
||||
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
|
||||
f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
|
||||
)
|
||||
else:
|
||||
set_speculate(speculate)
|
||||
|
@ -186,6 +187,14 @@ def get_model(
|
|||
raise RuntimeError(
|
||||
f"Could not determine model type for {model_id} revision {revision}"
|
||||
)
|
||||
quantization_config = config_dict.get("quantization_config", None)
|
||||
if quantization_config is not None and quantize is None:
|
||||
method = quantization_config.get("quant_method", None)
|
||||
if method in {"gptq", "awq"}:
|
||||
logger.info(f"Auto selecting quantization method {method}")
|
||||
quantize = method
|
||||
else:
|
||||
logger.info(f"Unknown quantization method {method}")
|
||||
|
||||
if model_type == "ssm":
|
||||
return Mamba(
|
||||
|
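The `quantization_config` branch above lets a quantized checkpoint advertise its own method through `config.json`, so users no longer have to pass `--quantize` for GPTQ/AWQ models. A small self-contained sketch of that decision; the helper name and the example config dict are made up for illustration.

```python
from typing import Optional


def auto_select_quantize(config_dict: dict, quantize: Optional[str]) -> Optional[str]:
    """Pick a quantization backend from the checkpoint config when the user
    did not pass --quantize, mirroring the logic in get_model above."""
    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq"}:
            return method  # auto-selected backend
    return quantize


# Example: a GPTQ checkpoint's config.json typically carries such a section.
example_config = {"model_type": "llama", "quantization_config": {"quant_method": "gptq", "bits": 4}}
assert auto_select_quantize(example_config, None) == "gptq"
```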
@ -571,6 +580,19 @@ def get_model(
|
|||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == "llava_next":
|
||||
if FLASH_ATTENTION:
|
||||
return LlavaNext(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("LlavaNext"))
|
||||
|
||||
if sharded:
|
||||
raise NotImplementedError("sharded is not supported for AutoModel")
|
||||
if quantize == "gptq":
|
||||
|
|
|
@ -47,7 +47,7 @@ class CacheManager:
|
|||
]
|
||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||
self.slots = torch.arange(
|
||||
0, num_blocks * self.block_size, dtype=torch.int32
|
||||
0, num_blocks * self.block_size, dtype=torch.int64
|
||||
).view(num_blocks, self.block_size)
|
||||
|
||||
def allocate(
|
||||
|
@ -59,9 +59,10 @@ class CacheManager:
|
|||
):
|
||||
# Get free blocks indices by finding values in mask that are not set to 0
|
||||
free_block_indices = self.free_block_mask.nonzero()
|
||||
assert (
|
||||
len(free_block_indices) >= blocks
|
||||
), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||
if blocks > len(free_block_indices):
|
||||
raise RuntimeError(
|
||||
f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||
)
|
||||
|
||||
# Slice by the number of required blocks
|
||||
block_indices = free_block_indices[:blocks]
|
||||
|
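Two things change in `CacheManager` above: the slot indices move to `int64`, and running out of KV-cache blocks now raises a `RuntimeError` instead of tripping an `assert`, presumably so the failure surfaces as a normal error rather than relying on a debug-only check. A standalone sketch of the allocation step, with an illustrative helper name:

```python
import torch


def allocate_blocks(free_block_mask: torch.Tensor, blocks: int) -> torch.Tensor:
    """Sketch of the allocation path above: pick `blocks` free KV-cache blocks
    out of a CPU mask of ones (free) and zeros (taken)."""
    free_block_indices = free_block_mask.nonzero()
    if blocks > len(free_block_indices):
        # Raising (instead of asserting) lets the caller surface a clean
        # "out of cache" error rather than crashing on a stripped assert.
        raise RuntimeError(
            f"Out of available cache blocks: asked {blocks}, "
            f"only {len(free_block_indices)} free blocks"
        )
    block_indices = free_block_indices[:blocks]
    # Mark the chosen blocks as used.
    free_block_mask[block_indices] = 0
    return block_indices.flatten()


mask = torch.ones(8, dtype=torch.int32)
print(allocate_blocks(mask, 3))  # tensor([0, 1, 2])
```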
|
|
@ -0,0 +1,827 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_create_4d_causal_attention_mask,
|
||||
_prepare_4d_attention_mask,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
)
|
||||
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
# TODO Should we TP this ?
|
||||
self.class_embedding = weights.get_tensor(f"{prefix}.class_embedding")
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.position_embedding", weights=weights
|
||||
)
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(self.num_positions, device=weights.device).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values.to(dtype=target_dtype)
|
||||
) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
config.max_position_embeddings, embed_dim
|
||||
)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(config.max_position_embeddings).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
seq_length = (
|
||||
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class CLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = self.embed_dim // self.num_heads
|
||||
if self.head_size * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||
self.scale = self.head_size**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return (
|
||||
tensor.view(bsz, seq_len, self.num_heads, self.head_size)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
|
||||
qkv = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
]
|
||||
* 3,
|
||||
dim=2,
|
||||
)
|
||||
query_states = query_states * self.scale
|
||||
key_states = self._shape(key_states, -1, bsz)
|
||||
value_states = self._shape(value_states, -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_size)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ causal_attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ attention_mask
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_probs = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class CLIPMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = CLIPAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.mlp = CLIPMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.layer_norm2 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPPreTrainedModel(nn.Module):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = CLIPConfig
|
||||
base_model_prefix = "clip"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
|
||||
CLIP_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
CLIP_TEXT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
"""
|
||||
|
||||
CLIP_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
"""
|
||||
|
||||
CLIP_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.max_position_embeddings - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
return_loss (`bool`, *optional*):
|
||||
Whether or not to return the contrastive loss.
|
||||
"""
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`CLIPEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: CLIPConfig
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
CLIPEncoderLayer(
|
||||
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
"""
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = CLIPTextEmbeddings(config)
|
||||
self.encoder = CLIPEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
# For `pooled_output` computation
|
||||
self.eos_token_id = config.eos_token_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify input_ids")
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
|
||||
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = _create_4d_causal_attention_mask(
|
||||
input_shape, hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(
|
||||
attention_mask, hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
if self.eos_token_id == 2:
|
||||
# The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
|
||||
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
|
||||
# ------------------------------------------------------------
|
||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(
|
||||
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||
),
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
|
||||
dim=-1
|
||||
),
|
||||
]
|
||||
else:
|
||||
# The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(
|
||||
last_hidden_state.shape[0], device=last_hidden_state.device
|
||||
),
|
||||
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
|
||||
(
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device)
|
||||
== self.eos_token_id
|
||||
)
|
||||
.int()
|
||||
.argmax(dim=-1),
|
||||
]
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
|
||||
class CLIPTextModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = CLIPTextTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CLIPTextModel
|
||||
|
||||
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
|
||||
```"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
return self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionTransformer(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPVisionConfig, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(
|
||||
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||
)
|
||||
self.pre_layrnorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
self.encoder = CLIPEncoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
# self.post_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
)
|
||||
last_hidden_state = encoder_outputs
|
||||
# pooled_output = last_hidden_state[:, 0, :]
|
||||
# pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
# pooler_output=pooled_output,
|
||||
# hidden_states=encoder_outputs,
|
||||
)
|
||||
|
||||
|
||||
class CLIPVisionModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPVisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_model = CLIPVisionTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPVisionModel
|
||||
|
||||
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
||||
```"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
)
|
||||
|
||||
|
||||
class CLIPModel(nn.Module):
|
||||
def __init__(self, prefix, config: CLIPConfig, weights):
|
||||
super().__init__()
|
||||
text_config = config.text_config
|
||||
vision_config = config.vision_config
|
||||
|
||||
self.projection_dim = config.projection_dim
|
||||
self.text_embed_dim = text_config.hidden_size
|
||||
self.vision_embed_dim = vision_config.hidden_size
|
||||
|
||||
self.text_model = CLIPTextTransformer(text_config)
|
||||
self.vision_model = CLIPVisionTransformer(vision_config)
|
||||
|
||||
self.visual_projection = nn.Linear(
|
||||
self.vision_embed_dim, self.projection_dim, bias=False
|
||||
)
|
||||
self.text_projection = nn.Linear(
|
||||
self.text_embed_dim, self.projection_dim, bias=False
|
||||
)
|
||||
self.logit_scale = nn.Parameter(
|
||||
torch.tensor(self.config.logit_scale_init_value)
|
||||
)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`CLIPTextModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
|
||||
>>> text_features = model.get_text_features(**inputs)
|
||||
```"""
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
text_features = self.text_projection(pooled_output)
|
||||
|
||||
return text_features
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Returns:
|
||||
image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
|
||||
applying the projection layer to the pooled output of [`CLIPVisionModel`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> image_features = model.get_image_features(**inputs)
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs[1] # pooled_output
|
||||
image_features = self.visual_projection(pooled_output)
|
||||
|
||||
return image_features
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, CLIPModel
|
||||
|
||||
>>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(
|
||||
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
|
||||
... )
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
```"""
|
||||
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
text_embeds = text_outputs[1]
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_per_image = logits_per_text.t()
|
||||
|
||||
return logits_per_image, logits_per_text
|
|
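One pattern worth calling out in the new CLIP module above is the fused QKV projection: `CLIPAttention` runs a single column-parallel matmul and then splits the result into query, key and value. A minimal sketch with a plain `nn.Linear` standing in for `TensorParallelColumnLinear`; the dimensions are made up.

```python
import torch
from torch import nn

# Fused-QKV pattern from CLIPAttention above: one matmul produces queries,
# keys and values, which are then split along the feature dimension.
embed_dim, num_heads = 64, 4
head_size = embed_dim // num_heads
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)

hidden_states = torch.randn(2, 7, embed_dim)            # (batch, seq, embed)
qkv = qkv_proj(hidden_states)                            # (batch, seq, 3 * embed)
query, key, value = qkv.split([head_size * num_heads] * 3, dim=2)
print(query.shape, key.shape, value.shape)               # each torch.Size([2, 7, 64])
```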
@ -23,10 +23,10 @@ import torch.distributed
|
|||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
|
@ -34,65 +34,106 @@ from text_generation_server.utils.layers import (
|
|||
PositionRotaryEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
FastRMSNorm,
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
||||
if IS_CUDA_SYSTEM:
|
||||
import dropout_layer_norm
|
||||
else:
|
||||
dropout_layer_norm = None
|
||||
|
||||
class CohereConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
|
||||
class CohereRotary(PositionRotaryEmbedding):
|
||||
def forward(
|
||||
self,
|
||||
vocab_size=256000,
|
||||
hidden_size=8192,
|
||||
intermediate_size=22528,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=8192,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=5,
|
||||
eos_token_id=255001,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=True,
|
||||
rope_theta=10000.0,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
logit_scale=1.0,
|
||||
**kwargs,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
# Such controlflows may add some overhead.
|
||||
if IS_CUDA_SYSTEM:
|
||||
import rotary_emb
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
q1 = query[..., ::2]
|
||||
q2 = query[..., 1::2]
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.logit_scale = logit_scale
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
k1 = key[..., ::2]
|
||||
k2 = key[..., 1::2]
|
||||
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
elif IS_ROCM_SYSTEM:
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
# NOTE: On ROCm systems, we use a RoPE implementation adapted from vLLM, which launches a single kernel for both query and key, contrary to the flash-attn implementation used on NVIDIA systems.
# Compiling the flash-attn rotary kernel on ROCm, hipcc appears unable to unroll the loops, resulting in even slower inference than eager mode: https://github.com/pytorch/pytorch/issues/113773
|
||||
|
||||
head_size = query.shape[-1]
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
|
||||
|
||||
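`CohereRotary.forward` above rotates the interleaved even/odd channels of the query and key in place, via the flash-attn `rotary_emb` kernel on CUDA and a fused vLLM kernel on ROCm. As a readable reference, an out-of-place pure-PyTorch version of the same interleaved rotation might look like the sketch below; the shapes are chosen for the sketch, not taken from the server.

```python
import torch


def apply_rotary_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch reference for the interleaved (even/odd channel) rotation
    that the fused kernels above perform in place.

    x:   (num_tokens, num_heads, head_size)
    cos: (num_tokens, 1, head_size // 2)
    sin: (num_tokens, 1, head_size // 2)
    """
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    rotated = torch.empty_like(x)
    rotated[..., ::2] = x1 * cos - x2 * sin
    rotated[..., 1::2] = x1 * sin + x2 * cos
    return rotated


q = torch.randn(5, 8, 64)
cos = torch.randn(5, 1, 32)
sin = torch.randn(5, 1, 32)
print(apply_rotary_interleaved(q, cos, sin).shape)  # torch.Size([5, 8, 64])
```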
class CohereLayerNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps):
|
||||
super().__init__()
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
self.weight = nn.Parameter(weight)
|
||||
# Fake weights
|
||||
self.ones = weight.new_ones(weight.shape[1])
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
|
||||
hidden_states = hidden_states.reshape(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
mean = hidden_states.mean(-1, keepdim=True)
|
||||
hidden_states_minus_mean = hidden_states - mean
|
||||
variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = self.weight.to(torch.float32) * hidden_states
|
||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
(
|
||||
hidden_states,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
None,
|
||||
self.ones,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
# Required to apply one weight matrix per head
|
||||
hidden_states = hidden_states.view(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
hidden_states = self.weight * hidden_states
|
||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
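`CohereLayerNorm.forward` above normalizes each attention head with its own weight vector and no bias, either through the fused `dropout_layer_norm` kernel or the eager fallback. A compact reference of the eager math, with an illustrative function name:

```python
import torch


def cohere_layer_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Readable reference for the per-head, bias-free LayerNorm above.

    hidden_states: (num_tokens * num_heads, head_size), weight: (num_heads, head_size).
    """
    x = hidden_states.reshape(-1, weight.shape[0], weight.shape[1]).float()
    mean = x.mean(-1, keepdim=True)
    var = (x - mean).pow(2).mean(-1, keepdim=True)
    x = (x - mean) * torch.rsqrt(var + eps)
    x = weight.float() * x                       # one weight vector per head, no bias
    return x.view(-1, weight.shape[1]).to(hidden_states.dtype)


h = torch.randn(6 * 8, 64)                       # 6 tokens, 8 heads, head_size 64
w = torch.ones(8, 64)
print(cohere_layer_norm(h, w, 1e-5).shape)       # torch.Size([48, 64])
```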
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
|
@ -154,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
self.rotary_emb = CohereRotary.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=config.rope_theta,
|
||||
|
@ -175,6 +216,22 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.use_qk_norm = config.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = CohereLayerNorm(
|
||||
prefix=f"{prefix}.q_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.k_norm = CohereLayerNorm(
|
||||
prefix=f"{prefix}.k_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
|
@ -199,21 +256,28 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
max_s,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
query, key, value = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
self.head_size * self.num_key_value_heads,
|
||||
self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
if self.use_qk_norm:
|
||||
query = query.reshape(-1, self.head_size)
|
||||
key = key.reshape(-1, self.head_size)
|
||||
query = self.q_norm(query.contiguous())
|
||||
key = self.k_norm(key.contiguous())
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
key = key.view(-1, self.num_key_value_heads, self.head_size)
|
||||
value = value.view(-1, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
@ -223,8 +287,8 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
key,
|
||||
value,
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
|
@ -298,7 +362,7 @@ class FlashCohereLayer(nn.Module):
|
|||
)
|
||||
self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
self.input_layernorm = FastLayerNorm.load_no_bias(
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
|
@ -362,7 +426,7 @@ class FlashCohereModel(torch.nn.Module):
|
|||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
self.norm = FastLayerNorm.load_no_bias(
|
||||
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
|
|
|
@ -16,14 +16,15 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
if not IS_XPU_SYSTEM:
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
|
@ -37,14 +38,6 @@ from text_generation_server.utils.layers import (
|
|||
)
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
HAS_MEGABLOCKS = True
|
||||
try:
|
||||
import stk
|
||||
import megablocks.ops as ops
|
||||
except ImportError:
|
||||
logger.warning("Dbrx: megablocks is not installed")
|
||||
HAS_MEGABLOCKS = False
|
||||
|
||||
|
||||
class DbrxAttentionConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
|
@ -531,18 +524,6 @@ def round_up(x: torch.Tensor, value: int):
|
|||
|
||||
|
||||
class BlockSparseMoE(nn.Module):
|
||||
"""
|
||||
Built on the paper and library Megablocks as described in
|
||||
https://arxiv.org/abs/2211.15841. This implementation is
|
||||
strictly equivalent to standard MoE with full capacity (no
|
||||
dropped tokens). It's faster since it formulates MoE operations
|
||||
in terms of block-sparse operations to accommodate imbalanced
|
||||
assignments of tokens to experts, whereas standard MoE either
|
||||
(1) drop tokens at the cost of reduced performance or (2) set
|
||||
capacity factor to number of experts and thus waste computation
|
||||
and memory on padding.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config: DbrxConfig, weights):
|
||||
super().__init__()
|
||||
self.moe_normalize_expert_weights = (
|
||||
|
@ -572,241 +553,40 @@ class BlockSparseMoE(nn.Module):
|
|||
)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
self.w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights)
|
||||
self.w2 = _load_experts(config, f"{prefix}.experts.mlp.w2", weights)
|
||||
self.v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights)
|
||||
|
||||
self.offsets = None
|
||||
self.offsets_block_rows = 0
|
||||
w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
self.wv1 = torch.cat([w1, v1], dim=1)
|
||||
self.w2 = (
|
||||
_load_experts(config, f"{prefix}.experts.mlp.w2", weights)
|
||||
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
# Calculate the number of bits needed to represent the expert indices
|
||||
# so that we can pass it to radix sort.
|
||||
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
|
||||
self.blocking = 128
|
||||
self.quantize_scatter_num_bits = -1
|
||||
|
||||
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
|
||||
padded_tokens, _ = x.size()
|
||||
assert padded_tokens % self.blocking == 0
|
||||
assert self.ffn_dim % self.blocking == 0
|
||||
|
||||
# Offsets for the sparse matrix. All rows have the
|
||||
# same number of nonzero blocks dictated by the
|
||||
# dimensionality of a single expert.
|
||||
block_rows = padded_tokens // self.blocking
|
||||
blocks_per_row = self.ffn_dim // self.blocking
|
||||
if self.offsets is None or block_rows > self.offsets_block_rows:
|
||||
self.offsets = torch.arange(
|
||||
0,
|
||||
block_rows * blocks_per_row + 1,
|
||||
blocks_per_row,
|
||||
dtype=torch.int32,
|
||||
device=x.device,
|
||||
)
|
||||
self.offsets_block_rows = block_rows
|
||||
offsets = self.offsets
|
||||
else:
|
||||
offsets = self.offsets[: block_rows + 1]
|
||||
|
||||
# Indices for the sparse matrix. The indices for
|
||||
# the intermediate matrix are dynamic depending
|
||||
# on the mapping of tokens to experts.
|
||||
column_indices = ops.topology(
|
||||
padded_bins, self.blocking, block_rows, blocks_per_row
|
||||
)
|
||||
|
||||
# For now, use meta init to save the device memory.
|
||||
data = torch.empty(
|
||||
column_indices.numel(),
|
||||
self.blocking,
|
||||
self.blocking,
|
||||
dtype=x.dtype,
|
||||
device="meta",
|
||||
)
|
||||
shape = (padded_tokens, self.ffn_dim * self.num_experts)
|
||||
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
|
||||
return stk.Matrix(
|
||||
shape,
|
||||
data,
|
||||
row_indices,
|
||||
column_indices,
|
||||
offsets,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
|
||||
# Sort the expert ids to produce the scatter/gather
|
||||
# indices for the permutation.
|
||||
# selected_experts = selected_experts.int()
|
||||
|
||||
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
|
||||
# and indices == how to sort tokens?
|
||||
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
|
||||
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
|
||||
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
|
||||
|
||||
# Histogram the expert ids to identify the number of
|
||||
# tokens routed to each expert.
|
||||
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
|
||||
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
|
||||
|
||||
# Round the token counts up to the block size used in
|
||||
# the matrix multiplications. Calculate the starting
|
||||
# position of each bin.
|
||||
|
||||
# List of size num_experts
|
||||
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
|
||||
# padded_tokens_per_expert => [128, O, 128, ...]
|
||||
|
||||
# Cumulative selected experts per token
|
||||
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
||||
padded_bins = promote_scalar(padded_bins)
|
||||
# padded_bins => [128, 128, 256, ...]
|
||||
|
||||
# Calculate the bin bounds for the sorted tokens.
|
||||
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
||||
bins = promote_scalar(bins)
|
||||
# bins => [3, 3, 5, ...]
|
||||
|
||||
return indices, bin_ids, bins, padded_bins, tokens_per_expert
|
||||
|
||||
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
selected_experts, weights = select_experts(
|
||||
gate_logits, self.top_k, self.moe_normalize_expert_weights
|
||||
)
|
||||
|
||||
(
|
||||
indices,
|
||||
bin_ids,
|
||||
bins,
|
||||
padded_bins,
|
||||
_,
|
||||
) = self.indices_and_padded_bins(selected_experts)
|
||||
|
||||
# Permute tokens and pad to prepare expert computation
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
|
||||
|
||||
# Create the sparse matrix topology
|
||||
with torch.no_grad():
|
||||
topo = self.topology(x, padded_bins)
|
||||
|
||||
# Perform the expert computation
|
||||
# First Dense x Dense -> Sparse for w1 and v1,
|
||||
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
||||
x = stk.Matrix(
|
||||
topo.size(),
|
||||
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
|
||||
* stk.ops.sdd(x, self.v1.t(), topo).data,
|
||||
topo.row_indices,
|
||||
topo.column_indices,
|
||||
topo.offsets,
|
||||
topo.column_indices_t,
|
||||
topo.offsets_t,
|
||||
topo.block_offsets_t,
|
||||
)
|
||||
|
||||
# Then Sparse x Dense -> Dense for w2
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = stk.ops.dsd(x, self.w2)
|
||||
|
||||
# Permute back and remove padding
|
||||
# (sequence_length, model_dim)
|
||||
x = ops.padded_scatter(
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
indices,
|
||||
bin_ids,
|
||||
weights,
|
||||
bins,
|
||||
padded_bins,
|
||||
self.wv1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
self.quantize_scatter_num_bits,
|
||||
).view(*input_shape)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(x, group=self.process_group)
|
||||
|
||||
return x.view(*input_shape)
|
||||
|
||||
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||
|
||||
if self.top_k < self.num_experts:
|
||||
_, not_selected_experts = torch.topk(
|
||||
weights,
|
||||
self.num_experts - self.top_k,
|
||||
largest=False,
|
||||
sorted=False,
|
||||
dim=1,
|
||||
)
|
||||
# Mask not selected experts
|
||||
weights.scatter_(1, not_selected_experts, 0)
|
||||
|
||||
# Re-normalize
|
||||
if self.moe_normalize_expert_weights:
|
||||
weights = weights / torch.norm(
|
||||
weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
|
||||
)
|
||||
weights = weights.to(x.dtype)
|
||||
|
||||
# Expand to [num_experts, sequence_length, model_dim]
|
||||
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
|
||||
|
||||
# Permute to [num_experts, model_dim, ffn_dim]
|
||||
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
renormalize=self.moe_normalize_expert_weights,
|
||||
inplace=True,
|
||||
)
|
||||
v1 = self.v1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
|
||||
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, v1)
|
||||
|
||||
out = torch.bmm(
|
||||
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
)
|
||||
# Mask not selected experts
|
||||
out *= weights.t().view(self.num_experts, -1, 1)
|
||||
|
||||
# Sum experts
|
||||
out = out.sum(0)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if len(x) > 256 and HAS_MEGABLOCKS:
|
||||
return self.sparse_forward(x)
|
||||
# This is faster when there is not a lot of tokens
|
||||
return self.dense_forward(x)
|
||||
return out.view(*x.shape)
|
||||
|
||||
|
||||
class DenseMoE(nn.Module):
|
||||
|
|
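The rewrite above drops the Megablocks block-sparse path and the dense fallback in favour of vLLM's `fused_moe` kernel, feeding it the stacked gate/up weights (`wv1`) and the transposed down projection (`w2`). The sketch below is a readable reference of what that routing computes for a SiLU-gated MoE under those weight layouts, not the fused implementation; the boolean `renormalize` is a simplification of the normalization argument used above.

```python
import torch


def moe_reference(x, wv1, w2, router_logits, top_k, renormalize=True):
    """Reference MoE: wv1 is (E, 2*ffn, hidden) with gate (w1) and up (v1)
    projections stacked per expert, w2 is (E, hidden, ffn)."""
    num_experts = wv1.shape[0]
    ffn_dim = wv1.shape[1] // 2
    weights = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(weights, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for e in range(num_experts):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ wv1[e].t()                         # (tokens, 2*ffn)
        gate, up = h[:, :ffn_dim], h[:, ffn_dim:]
        expert_out = (torch.nn.functional.silu(gate) * up) @ w2[e].t()
        out[token_idx] += topk_weights[token_idx, slot].unsqueeze(-1).to(x.dtype) * expert_out
    return out


x = torch.randn(5, 16)
wv1 = torch.randn(4, 2 * 32, 16)   # 4 experts, ffn_dim 32, hidden 16
w2 = torch.randn(4, 16, 32)
print(moe_reference(x, wv1, w2, torch.randn(5, 4), top_k=2).shape)  # torch.Size([5, 16])
```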
|
@ -281,9 +281,8 @@ class LlamaMLP(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
|
@ -337,27 +336,30 @@ class FlashLlamaLayer(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=(
|
||||
f"model.layers.{layer_id}"
|
||||
if not prefix
|
||||
else f"{prefix}.model.layers.{layer_id}"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
@ -368,7 +370,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
|
@ -376,8 +378,10 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
|
@ -406,13 +410,19 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = FlashLlamaModel(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
|
@ -426,10 +436,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
@ -437,6 +449,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@@ -285,9 +285,8 @@ class MistralMLP(nn.Module):


 class MistralLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
         self.self_attn = MistralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
|
||||
|
@ -343,27 +342,24 @@ class MistralLayer(nn.Module):
|
|||
|
||||
|
||||
class MistralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MistralLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
@ -374,7 +370,7 @@ class MistralModel(torch.nn.Module):
|
|||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
|
@ -384,9 +380,8 @@ class MistralModel(torch.nn.Module):
|
|||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
|
@ -410,18 +405,27 @@ class MistralModel(torch.nn.Module):
|
|||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashMistralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = MistralModel(config, weights)
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = MistralModel(
|
||||
prefix="model" if not prefix else f"{prefix}.model",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
|
@ -453,8 +457,9 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
# kernel requires the true values
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
|
|
@ -24,6 +24,10 @@ import torch.distributed
|
|||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
|
||||
|
||||
if not IS_XPU_SYSTEM:
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
@ -41,14 +45,6 @@ from text_generation_server.utils.layers import (
|
|||
get_linear,
|
||||
)
|
||||
|
||||
HAS_MEGABLOCKS = True
|
||||
try:
|
||||
import stk
|
||||
import megablocks.ops as ops
|
||||
except ImportError:
|
||||
logger.warning("Mixtral: megablocks is not installed")
|
||||
HAS_MEGABLOCKS = False
|
||||
|
||||
|
||||
class MixtralConfig(PretrainedConfig):
|
||||
model_type = "mixtral"
|
||||
|
@ -321,18 +317,6 @@ def round_up(x: torch.Tensor, value: int):
|
|||
|
||||
|
||||
class BlockSparseMoE(nn.Module):
|
||||
"""
|
||||
Built on the paper and library Megablocks as described in
|
||||
https://arxiv.org/abs/2211.15841. This implementation is
|
||||
strictly equivalent to standard MoE with full capacity (no
|
||||
dropped tokens). It's faster since it formulates MoE operations
|
||||
in terms of block-sparse operations to accomodate imbalanced
|
||||
assignments of tokens to experts, whereas standard MoE either
|
||||
(1) drop tokens at the cost of reduced performance or (2) set
|
||||
capacity factor to number of experts and thus waste computation
|
||||
and memory on padding.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config: MixtralConfig, weights):
|
||||
super().__init__()
|
||||
self.hidden_dim = config.hidden_size
|
||||
|
@ -357,236 +341,40 @@ class BlockSparseMoE(nn.Module):
|
|||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
|
||||
self.w1 = _load_experts(config, f"{prefix}.experts", "w1", weights)
|
||||
self.w2 = _load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
self.w3 = _load_experts(config, f"{prefix}.experts", "w3", weights)
|
||||
|
||||
self.offsets = None
|
||||
self.offsets_block_rows = 0
|
||||
w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
|
||||
self.num_experts, self.ffn_dim, self.hidden_dim
|
||||
)
|
||||
self.w13 = torch.cat([w1, w3], dim=1)
|
||||
self.w2 = (
|
||||
_load_experts(config, f"{prefix}.experts", "w2", weights)
|
||||
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
# Calculate the number of bits needed to represent the expert indices
|
||||
# so that we can pass it to radix sort.
|
||||
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
|
||||
self.blocking = 128
|
||||
self.quantize_scatter_num_bits = -1
|
||||
|
||||
def topology(self, x: torch.Tensor, padded_bins: torch.Tensor):
|
||||
padded_tokens, _ = x.size()
|
||||
assert padded_tokens % self.blocking == 0
|
||||
assert self.ffn_dim % self.blocking == 0
|
||||
|
||||
# Offsets for the sparse matrix. All rows have the
|
||||
# same number of nonzero blocks dictated by the
|
||||
# dimensionality of a single expert.
|
||||
block_rows = padded_tokens // self.blocking
|
||||
blocks_per_row = self.ffn_dim // self.blocking
|
||||
if self.offsets is None or block_rows > self.offsets_block_rows:
|
||||
self.offsets = torch.arange(
|
||||
0,
|
||||
block_rows * blocks_per_row + 1,
|
||||
blocks_per_row,
|
||||
dtype=torch.int32,
|
||||
device=x.device,
|
||||
)
|
||||
self.offsets_block_rows = block_rows
|
||||
offsets = self.offsets
|
||||
else:
|
||||
offsets = self.offsets[: block_rows + 1]
|
||||
|
||||
# Indices for the sparse matrix. The indices for
|
||||
# the intermediate matrix are dynamic depending
|
||||
# on the mapping of tokens to experts.
|
||||
column_indices = ops.topology(
|
||||
padded_bins, self.blocking, block_rows, blocks_per_row
|
||||
)
|
||||
|
||||
# For now, use meta init to save the device memory.
|
||||
data = torch.empty(
|
||||
column_indices.numel(),
|
||||
self.blocking,
|
||||
self.blocking,
|
||||
dtype=x.dtype,
|
||||
device="meta",
|
||||
)
|
||||
shape = (padded_tokens, self.ffn_dim * self.num_experts)
|
||||
row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
|
||||
return stk.Matrix(
|
||||
shape,
|
||||
data,
|
||||
row_indices,
|
||||
column_indices,
|
||||
offsets,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
def indices_and_padded_bins(self, selected_experts: torch.Tensor):
|
||||
# Sort the expert ids to produce the scatter/gather
|
||||
# indices for the permutation.
|
||||
# selected_experts = selected_experts.int()
|
||||
|
||||
# returns bin_ids == num of experts for this sequence ? == unique selected experts?
|
||||
# and indices == how to sort tokens?
|
||||
bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
|
||||
# bin_ids => [0, 0, 0, 2, 2, ...] => [num_tokens * top_k]
|
||||
# indices => [14, 32, 33, ...] => [num_tokens * top_k]
|
||||
|
||||
# Histogram the expert ids to identify the number of
|
||||
# tokens routed to each expert.
|
||||
tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
|
||||
# tokens_per_expert => [3, 0, 2, ...] => [num_experts]
|
||||
|
||||
# Round the token counts up to the block size used in
|
||||
# the matrix muliplications. Caculate the starting
|
||||
# position of each bin.
|
||||
|
||||
# List of size num_experts
|
||||
padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
|
||||
# padded_tokens_per_expert => [128, O, 128, ...]
|
||||
|
||||
# Cumulative selected experts per token
|
||||
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
|
||||
padded_bins = promote_scalar(padded_bins)
|
||||
# padded_bins => [128, 128, 256, ...]
|
||||
|
||||
# Calculate the bin bounds for the sorted tokens.
|
||||
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
||||
bins = promote_scalar(bins)
|
||||
# bins => [3, 3, 5, ...]
|
||||
|
||||
return indices, bin_ids, bins, padded_bins, tokens_per_expert
|
||||
|
||||
def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
selected_experts, weights = select_experts(gate_logits, self.top_k)
|
||||
|
||||
(
|
||||
indices,
|
||||
bin_ids,
|
||||
bins,
|
||||
padded_bins,
|
||||
_,
|
||||
) = self.indices_and_padded_bins(selected_experts)
|
||||
|
||||
# Permute tokens and pad to prepare expert computation
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
|
||||
|
||||
# Create the sparse matrix topology
|
||||
with torch.no_grad():
|
||||
topo = self.topology(x, padded_bins)
|
||||
|
||||
# Perform the expert computation
|
||||
# First Dense x Dense -> Sparse for w1 and w3,
|
||||
# (top_k * sequence_length + padding, ffn_dim * n_experts)
|
||||
x = stk.Matrix(
|
||||
topo.size(),
|
||||
self.act(stk.ops.sdd(x, self.w1.t(), topo).data)
|
||||
* stk.ops.sdd(x, self.w3.t(), topo).data,
|
||||
topo.row_indices,
|
||||
topo.column_indices,
|
||||
topo.offsets,
|
||||
topo.column_indices_t,
|
||||
topo.offsets_t,
|
||||
topo.block_offsets_t,
|
||||
)
|
||||
|
||||
# Then Sparse x Dense -> Dense for w2
|
||||
# (top_k * sequence_length + padding, model_dim)
|
||||
x = stk.ops.dsd(x, self.w2)
|
||||
|
||||
# Permute back and remove padding
|
||||
# (sequence_length, model_dim)
|
||||
x = ops.padded_scatter(
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
x,
|
||||
indices,
|
||||
bin_ids,
|
||||
weights,
|
||||
bins,
|
||||
padded_bins,
|
||||
self.w13,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
self.quantize_scatter_num_bits,
|
||||
).view(*input_shape)
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(x, group=self.process_group)
|
||||
|
||||
return x.view(*input_shape)
|
||||
|
||||
def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (sequence_length, model_dim)
|
||||
gate_logits: (sequence_length, n_experts)
|
||||
"""
|
||||
# optional reshape
|
||||
input_shape = x.shape
|
||||
x = x.view(-1, input_shape[-1])
|
||||
|
||||
# gate_logits: (sequence_length, n_experts)
|
||||
gate_logits = self.gate(x)
|
||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
||||
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
||||
|
||||
if self.top_k < self.num_experts:
|
||||
_, not_selected_experts = torch.topk(
|
||||
all_probs,
|
||||
self.num_experts - self.top_k,
|
||||
largest=False,
|
||||
sorted=False,
|
||||
dim=1,
|
||||
)
|
||||
# Mask not selected experts
|
||||
all_probs.scatter_(1, not_selected_experts, 0)
|
||||
|
||||
# Re-normalize
|
||||
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
|
||||
weights = weights.to(x.dtype)
|
||||
|
||||
# Expand to [num_experts, sequence_length, model_dim]
|
||||
x = x.view(1, -1, input_shape[-1]).expand(self.num_experts, -1, input_shape[-1])
|
||||
|
||||
# Permute to [num_experts, model_dim, ffn_dim]
|
||||
w1 = self.w1.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
)
|
||||
w3 = self.w3.view(self.num_experts, self.ffn_dim, self.hidden_dim).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
|
||||
inter = self.act(torch.bmm(x, w1)) * torch.bmm(x, w3)
|
||||
|
||||
out = torch.bmm(
|
||||
inter, self.w2.view(self.num_experts, self.ffn_dim, self.hidden_dim)
|
||||
)
|
||||
# Mask not selected experts
|
||||
out *= weights.t().view(self.num_experts, -1, 1)
|
||||
|
||||
# Sum experts
|
||||
out = out.sum(0)
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if len(x) > 256 and HAS_MEGABLOCKS:
|
||||
return self.sparse_forward(x)
|
||||
# This is faster when there is not a lot of tokens
|
||||
return self.dense_forward(x)
|
||||
return out.view(*x.shape)
|
||||
|
||||
|
||||
class DenseMoE(nn.Module):
|
||||
|
@ -679,9 +467,9 @@ class DenseMoE(nn.Module):
|
|||
|
||||
|
||||
class MixtralLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
prefix = f"{prefix}.layers.{layer_id}"
|
||||
|
||||
self.self_attn = MixtralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
|
@ -740,16 +528,20 @@ class MixtralLayer(nn.Module):
|
|||
|
||||
|
||||
class MixtralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MixtralLayer(
|
||||
"model" if not prefix else f"{prefix}.model",
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
|
@ -758,7 +550,9 @@ class MixtralModel(torch.nn.Module):
|
|||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
|
@ -808,13 +602,13 @@ class MixtralModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashMixtralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = MixtralModel(config, weights)
|
||||
self.model = MixtralModel(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
|
|
|
@ -0,0 +1,302 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
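A hedged sketch of what unpad_image does (toy tensor, not part of the diff): an 8x8 feature map that came from a 100x200 (height x width) original is wider than it is tall, so the rows of vertical padding are trimmed away.

    import torch

    padded = torch.zeros(3, 8, 8)                 # (channels, height, width)
    print(unpad_image(padded, (100, 200)).shape)  # torch.Size([3, 4, 8])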
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def load_vision_model(prefix, config, weights):
|
||||
if config.model_type == "clip_vision_model":
|
||||
from text_generation_server.models.custom_modeling.clip import (
|
||||
CLIPVisionTransformer,
|
||||
)
|
||||
|
||||
return CLIPVisionTransformer(
|
||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
def load_text_model(prefix, config, weights):
|
||||
if config.model_type == "llama":
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
|
||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
||||
elif config.model_type == "mistral":
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
|
||||
return FlashMistralForCausalLM(prefix, config, weights)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Instead of selecting in hidden_states[-2].
|
||||
# Instead compute only the n -2 + 1 layers and don't pool
|
||||
if config.vision_feature_layer < 0:
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
else:
|
||||
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
prefix="multi_modal_projector", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.config = config
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.use_medusa = config.use_medusa
|
||||
self.language_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
|
||||
# 2. Merge text and images
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(
|
||||
num_images * num_patches, channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.config.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [num_patches] * num_images
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
new_image_features = []
|
||||
for image_idx, image_feature in enumerate(image_features):
|
||||
if image_feature.shape[0] > 1:
|
||||
base_image_feature = image_feature[0]
|
||||
image_feature = image_feature[1:]
|
||||
|
||||
if height * width != base_image_feature.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
self.image_newline[:, None, None].expand(
|
||||
*image_feature.shape[:-1], 1
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||||
image_feature = torch.cat(
|
||||
(base_image_feature, image_feature), dim=0
|
||||
)
|
||||
else:
|
||||
image_feature = image_feature[0]
|
||||
image_feature = torch.cat(
|
||||
(image_feature, self.image_newline[None]), dim=0
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.language_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
|
@@ -28,12 +28,17 @@ from text_generation_server.models.cache_manager import (
     BLOCK_SIZE,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL, ENABLE_CUDA_GRAPHS
+from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION

 tracer = trace.get_tracer(__name__)
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)


 @dataclass
 class FlashCausalLMBatch(Batch):
|
||||
|
@ -106,6 +111,19 @@ class FlashCausalLMBatch(Batch):
|
|||
max_tokens=self.blocks * BLOCK_SIZE,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def batch_tokenized_inputs(cls, requests, tokenizer):
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
return batch_tokenized_inputs
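Reviewer note: tokenization is factored out of from_pb so that subclasses (for example the VLM batch further down) can reuse it. A hedged sketch of calling the classmethod directly with toy request objects (SimpleNamespace stands in for the real protobuf requests; the tokenizer id is only an example):

    from types import SimpleNamespace
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    requests = [
        SimpleNamespace(inputs="Hello world", truncate=8),
        SimpleNamespace(inputs="A somewhat longer prompt to tokenize", truncate=4),
    ]
    ids = FlashCausalLMBatch.batch_tokenized_inputs(requests, tokenizer)
    print([len(x) for x in ids])  # each sequence is cut at max(truncate) == 8 here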
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
|
@ -114,16 +132,7 @@ class FlashCausalLMBatch(Batch):
|
|||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
position_ids = []
|
||||
speculative_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
|
@ -165,6 +174,11 @@ class FlashCausalLMBatch(Batch):
|
|||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
if (
|
||||
tokenized_input[0] == tokenizer.bos_token_id
|
||||
and tokenized_input[1] == tokenizer.bos_token_id
|
||||
):
|
||||
tokenized_input = tokenized_input[1:]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
@ -690,7 +704,7 @@ class FlashCausalLM(Model):
|
|||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
block_tables = (
|
||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||
|
@ -779,14 +793,16 @@ class FlashCausalLM(Model):
|
|||
|
||||
if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
|
||||
total_free_memory, _ = torch.cuda.mem_get_info(self.device)
|
||||
total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
|
||||
total_gpu_memory = torch.cuda.get_device_properties(
|
||||
self.device
|
||||
).total_memory
|
||||
|
||||
free_memory = max(
|
||||
0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
|
||||
)
|
||||
elif IS_XPU_SYSTEM:
|
||||
total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
|
||||
free_memory = int(total_gpu_memory *0.5)
|
||||
free_memory = int(total_gpu_memory * 0.5)
|
||||
else:
|
||||
raise NotImplementedError("FlashModel is only available on GPU")
|
||||
|
||||
|
@@ -810,14 +826,14 @@ class FlashCausalLM(Model):
             self.device,
         )

-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             try:
-                logger.info("Experimental support for Cuda Graphs is enabled")
+                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
-            except Exception:
+            except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")

         return int(num_blocks * BLOCK_SIZE)
|
||||
|
@@ -877,22 +893,14 @@ class FlashCausalLM(Model):
         lm_head_indices = batch.prefill_head_indices

         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None

-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if (
-            cu_seqlen_prefill is not None
-            or cuda_graph is None
-            or batch.speculative_ids is not None
-        ):
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
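Reviewer note: decode now looks up the smallest captured graph that can hold the batch instead of rounding the batch size by hand. A hedged standalone sketch of that selection rule (toy dict and hypothetical helper name, not part of the diff):

    def select_cuda_graph(cuda_graphs: dict, bs: int):
        # Smallest captured batch size that is >= the current batch size, else None.
        candidates = sorted(k for k in cuda_graphs if k >= bs)
        return cuda_graphs[candidates[0]] if candidates else None

    graphs = {1: "graph_1", 2: "graph_2", 4: "graph_4", 8: "graph_8"}
    print(select_cuda_graph(graphs, 3))   # graph_4: falls through to the next captured size
    print(select_cuda_graph(graphs, 16))  # None: no graph is large enough, eager forward is used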
|
||||
|
|
|
@@ -3,12 +3,11 @@ import torch.distributed

 from opentelemetry import trace
 from typing import Optional
-from transformers.models.llama import LlamaTokenizerFast
+from transformers import AutoTokenizer, AutoConfig

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
     FlashCohereForCausalLM,
-    CohereConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
|
||||
|
@@ -32,11 +31,11 @@ class FlashCohere(FlashCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashCohere is only available on GPU")

-        tokenizer = LlamaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
|
||||
|
@@ -46,7 +45,7 @@ class FlashCohere(FlashCausalLM):
             from_slow=False,
         )

-        config = CohereConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
|
||||
|
|
|
@@ -21,6 +21,7 @@ tracer = trace.get_tracer(__name__)

+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM


 class FlashLlama(FlashCausalLM):
     def __init__(
         self,
|
||||
|
@@ -71,7 +72,8 @@ class FlashLlama(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)

-        model = FlashLlamaForCausalLM(config, weights)
+        prefix = ""
+        model = FlashLlamaForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,
|
||||
|
|
|
@ -6,8 +6,7 @@ import numpy as np
|
|||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
|
@ -66,19 +65,21 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
|
||||
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
|
||||
@classmethod
|
||||
def from_tokenized(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
batch_tokenized_inputs,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
sliding_window, sliding_window_blocks = get_sliding_windows()
|
||||
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
needed_blocks_slots = []
|
||||
|
@ -302,14 +303,15 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
|||
class BaseFlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config_cls,
|
||||
model_cls,
|
||||
model_id: str,
|
||||
config_cls=AutoConfig,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_class=AutoTokenizer,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
|
@ -321,22 +323,13 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
else:
|
||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||
|
||||
try:
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
except Exception:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = config_cls.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
|
@@ -345,10 +338,12 @@ class BaseFlashMistral(FlashCausalLM):
         config.use_medusa = use_medusa

         # Set context windows
-        if config.sliding_window is not None:
+        if getattr(config, "sliding_window", None) is not None:
             set_sliding_window(
                 config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
             )
+        else:
+            config.sliding_window = None

         torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
|
@ -357,17 +352,19 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = model_cls(config, weights)
|
||||
prefix = ""
|
||||
model = model_cls(prefix, config, weights)
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
|
||||
super().__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
|
@@ -375,6 +372,16 @@ class BaseFlashMistral(FlashCausalLM):
             sliding_window=config.sliding_window,
         )

+    def get_layer_config(self, model) -> Tuple[int, int, int]:
+        return (
+            len(model.model.layers),
+            model.model.num_key_value_heads,
+            model.model.head_size,
+        )
+
+    def max_past(self) -> int:
+        return self.model.max_past
+
     @property
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch
|
||||
|
@ -382,7 +389,7 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
block_tables = (
|
||||
torch.arange(max_bt, dtype=torch.int32, device=self.device)
|
||||
|
@ -489,11 +496,11 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
max_s = batch.max_seqlen
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.model.max_past is not None:
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.model.max_past, max_s)
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
|
|
|
@@ -3,4 +3,12 @@ import os

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
-ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
+cuda_graphs = os.getenv("CUDA_GRAPHS")
+if cuda_graphs is not None:
+    try:
+        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
+    except Exception as e:
+        raise RuntimeError(
+            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
+        )
+CUDA_GRAPHS = cuda_graphs if torch.cuda.is_available() else None
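Reviewer note: the boolean ENABLE_CUDA_GRAPHS toggle becomes a CUDA_GRAPHS list of batch sizes. A minimal standalone sketch of the same parsing (the helper name parse_cuda_graphs is made up for illustration and is not part of this diff):

    import os

    def parse_cuda_graphs(value: str):
        # Comma separated list of batch sizes to capture, e.g. CUDA_GRAPHS="1,2,4,8".
        try:
            return [int(item) for item in value.split(",")]
        except ValueError as e:
            raise RuntimeError(f"Could not parse cuda graphs {value}: {e}")

    print(parse_cuda_graphs(os.getenv("CUDA_GRAPHS", "1,2,4,8")))  # -> [1, 2, 4, 8]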
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import torch
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
@ -20,29 +21,13 @@ from text_generation_server.models.types import (
|
|||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
from text_generation_server.models.vlm_causal_lm import split
|
||||
|
||||
import re
|
||||
|
||||
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
|
||||
|
||||
|
||||
def split(string):
|
||||
parts = []
|
||||
cursor = 0
|
||||
for pattern in IMAGES.finditer(string):
|
||||
start = pattern.start()
|
||||
if start != cursor:
|
||||
parts.append(string[cursor:start])
|
||||
|
||||
parts.append(pattern.group(1))
|
||||
cursor = pattern.end()
|
||||
|
||||
if cursor != len(string):
|
||||
parts.append(string[cursor:])
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
|
@ -93,10 +78,21 @@ class IdeficsCausalLMBatch(Batch):
|
|||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_pb_processor(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
processor: ProcessorMixin, # Hack
|
||||
config,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "IdeficsCausalLMBatch":
|
||||
|
@ -127,10 +123,14 @@ class IdeficsCausalLMBatch(Batch):
|
|||
padding_right_offset, stopping_criteria.max_new_tokens
|
||||
)
|
||||
|
||||
# TODO Check impact on idefics
|
||||
prompts = []
|
||||
for inp in inputs:
|
||||
# Each input is encoded into a list, where each element of this input list is either a string or a URL
|
||||
prompts.append(split(inp))
|
||||
prompt = []
|
||||
for chunk in split(inp):
|
||||
prompt.append(chunk["content"])
|
||||
prompts.append(prompt)
|
||||
|
||||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
|
@ -141,7 +141,8 @@ class IdeficsCausalLMBatch(Batch):
|
|||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_truncation,
|
||||
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
# TODO Check impact on idefics
|
||||
# add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
|
||||
).to(device)
|
||||
for _ in pb.requests:
|
||||
input_len = tokenized_inputs["input_ids"].shape[1]
|
||||
|
@ -156,7 +157,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
max_input_length = input_lengths.max()
|
||||
|
||||
input_ids = tokenized_inputs["input_ids"]
|
||||
pixel_values = tokenized_inputs["pixel_values"]
|
||||
pixel_values = tokenized_inputs.get("pixel_values", None)
|
||||
image_hidden_states = None
|
||||
# Allocate maximum attention_mask
|
||||
attention_mask = input_ids.new_zeros(
|
||||
|
@ -165,16 +166,19 @@ class IdeficsCausalLMBatch(Batch):
|
|||
# Copy tokenizer attention_mask into fully allocated attention_mask
|
||||
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
|
||||
# Do the same for image_attention_mask
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
tokenized_inputs["pixel_values"].size(1),
|
||||
if pixel_values is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
max_input_length + padding_right_offset,
|
||||
pixel_values.size(1),
|
||||
)
|
||||
)
|
||||
)
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
|
@ -677,19 +681,22 @@ class IdeficsCausalLM(Model):
|
|||
start = time.time_ns()
|
||||
# slice the attention mask to the correct shape
|
||||
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
||||
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, -(batch.padding_right_offset + 1)
|
||||
].unsqueeze(1)
|
||||
if batch.image_attention_mask is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
if batch.input_ids.size(1) == 1:
|
||||
# THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),
|
||||
# but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension
|
||||
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
|
||||
# token need to attend to the encoder hidden states (i.e. the vision encoder)
|
||||
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, -(batch.padding_right_offset + 1)
|
||||
].unsqueeze(1)
|
||||
else:
|
||||
image_attention_mask = batch.image_attention_mask[
|
||||
:, : -batch.padding_right_offset
|
||||
]
|
||||
|
||||
logits, speculative_logits, past, image_hidden_states = self.forward(
|
||||
input_ids=batch.input_ids,
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import torch
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
|
||||
|
||||
class LlavaNext(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
super().__init__(
|
||||
model_cls=LlavaNextForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
|
@@ -13,7 +13,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.models.globals import ENABLE_CUDA_GRAPHS, MEM_POOL
+from text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL
 import time
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaModel,
|
||||
|
@@ -465,12 +465,12 @@ class Mamba(Model):

     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
-        if ENABLE_CUDA_GRAPHS:
+        if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
                 try:
-                    logger.info("Experimental support for Cuda Graphs is enabled")
+                    logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                     # Warmup cuda graphs
-                    for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
+                    for bs in CUDA_GRAPHS:
                         self.cuda_graph_warmup(bs)
                 except Exception:
                     logger.exception(f"Decode cuda graph warmup failed")
|
||||
|
|
|
@ -0,0 +1,329 @@
|
|||
import re
|
||||
import torch
|
||||
import math
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import base64
|
||||
|
||||
from opentelemetry import trace
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
FlashMistralBatch,
|
||||
)
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
|
||||
|
||||
|
||||
def split(string) -> List[Dict[str, str]]:
|
||||
parts = []
|
||||
cursor = 0
|
||||
for pattern in IMAGES.finditer(string):
|
||||
start = pattern.start()
|
||||
if start != cursor:
|
||||
parts.append({"type": "text", "content": string[cursor:start]})
|
||||
|
||||
parts.append({"type": "image", "content": pattern.group(1)})
|
||||
cursor = pattern.end()
|
||||
|
||||
if cursor != len(string):
|
||||
parts.append({"type": "text", "content": string[cursor:]})
|
||||
|
||||
return parts
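For context while reviewing: split now returns typed chunks instead of bare strings. A hedged usage sketch (the prompt string below is made up):

    text = "Describe ![](data:image/png;base64,AAAA) in one sentence."
    print(split(text))
    # [{'type': 'text', 'content': 'Describe '},
    #  {'type': 'image', 'content': 'data:image/png;base64,AAAA'},
    #  {'type': 'text', 'content': ' in one sentence.'}]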
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
patch_size (`int`):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
|
||||
height, width = select_best_resolution(image_size, grid_pinpoints)
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def get_number_of_features(height: int, width: int, config) -> int:
|
||||
# From config
|
||||
# Hardcoded for CLIP for now
|
||||
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
|
||||
image_grid_pinpoints = config.image_grid_pinpoints
|
||||
image_size = config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
|
||||
assert image_size % patch_size == 0
|
||||
|
||||
npatches = image_size // patch_size
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
[height, width],
|
||||
image_grid_pinpoints,
|
||||
image_size,
|
||||
)
|
||||
|
||||
height_of_patch = math.ceil(height / width * npatches)
|
||||
|
||||
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
|
||||
# They are only added after width
|
||||
newline_features = height_of_patch * num_patch_width
|
||||
# The base patch covers the entire image
|
||||
base_features = npatches**2
|
||||
return unpadded_features + newline_features + base_features
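A hedged sanity check for the feature-count math above, using a stand-in config for the CLIP-ViT-L/14-336 tower the comment assumes (the SimpleNamespace objects are illustration only); the commented-out asserts below expect 2634 features for an 889x1024 image:

    from types import SimpleNamespace

    clip_config = SimpleNamespace(
        image_grid_pinpoints=[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
        vision_config=SimpleNamespace(image_size=336, patch_size=14),
    )
    print(get_number_of_features(889, 1024, clip_config))  # expected: 2634 per the assert below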
|
||||
|
||||
|
||||
def load_data_uri(image_uri: str) -> Image.Image:
|
||||
image_uri = image_uri.split(",")[-1]
|
||||
content = base64.b64decode(image_uri)
|
||||
image = Image.open(BytesIO(content))
|
||||
return image
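A hedged round-trip sketch for load_data_uri (not part of the diff): build a data URI from a tiny PIL image and read it back.

    from io import BytesIO
    import base64
    from PIL import Image

    buf = BytesIO()
    Image.new("RGB", (8, 8), "red").save(buf, format="PNG")
    uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
    assert load_data_uri(uri).size == (8, 8)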
|
||||
|
||||
|
||||
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
|
||||
# assert get_number_of_features(640, 640) == 2928
|
||||
|
||||
|
||||
class VlmCausalLMBatch(FlashMistralBatch):
    pixel_values: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches):
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.image_sizes = None
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.image_sizes = None
        return batch

    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenizer, processor, config):
        batch_inputs = []
        image_inputs = []
        max_truncation = 0
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += chunk["content"]
                elif chunk["type"] == "image":
                    image = chunk["content"]
                    # Should never receive URLs anymore, processing should be done
                    # On the rust layer.
                    # This avoid making n queries per TP
                    # if image.startswith("https://") or image.startswith("http://"):
                    #     image = processor.image_processor.fetch_images(image)
                    if image.startswith("data:"):
                        image = load_data_uri(image)
                    else:
                        raise RuntimeError(
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
                    height, width = image_input["image_sizes"][0]
                    num_features = get_number_of_features(height, width, config)
                    full_text += "<image>" * num_features
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")

            batch_inputs.append(full_text)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        if image_inputs:
            image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
            }
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmCausalLMBatch":
        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
            pb.requests, tokenizer, processor, config
        )
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
        else:
            batch.pixel_values = None
            batch.image_sizes = None
        return batch

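# --- Illustrative sketch, not part of the diff: producing the "data:" URI that ---
# --- batch_tokenized_inputs expects for image chunks. The helper name and file ---
# --- path are hypothetical; only the URI shape is taken from the code above.   ---
def to_data_uri(path: str, mime: str = "image/png") -> str:
    # Encode a local image file as a base64 data URI; load_data_uri() above
    # simply strips everything before the comma and base64-decodes the rest.
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{payload}"

# How the image is embedded in the prompt string is handled upstream (split()
# and the Rust router); each image chunk is then expanded into `num_features`
# "<image>" placeholder tokens before tokenization.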
class VlmCausalLM(BaseFlashMistral):
    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.language_model, "max_past", None)

    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Model Forward
        if batch.speculative_ids is not None:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

            speculative_ids = batch.speculative_ids

            B, speculative_length = speculative_ids.shape
            new_length = speculative_length + 1
            new_input_ids = torch.cat(
                [input_ids.unsqueeze(-1), speculative_ids], dim=1
            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
            arange_int = arange.to(dtype=torch.int32)
            new_position_ids = (
                position_ids.unsqueeze(-1).expand(B, new_length) + arange
            ).view(-1)
            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)

            # Add Copy the block tables for all members
            block_tables = (
                block_tables.unsqueeze(1)
                .expand(B, new_length, -1)
                .reshape(B * new_length, -1)
                .contiguous()
            )
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids
        else:
            input_ids = batch.input_ids
            position_ids = batch.position_ids
            cu_seqlen_prefill = batch.cu_seqlen_prefill
            kv_cache = get_cache_manager().kv_cache
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        padded_bs = bs
        if bs == 3:
            padded_bs = 4
        elif 3 < bs <= 8:
            padded_bs = 8
        elif bs > 8:
            padded_bs = (bs + 7) // 8 * 8

        # Try to find an associated cuda graph
        cuda_graph = self.cuda_graphs.get(padded_bs, None)

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
                cu_seqlen_prefill=cu_seqlen_prefill,
                kv_cache=kv_cache,
                block_tables=block_tables,
                slots=slots,
                input_lengths=input_lengths,
                max_s=max_s,
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                image_sizes=batch.image_sizes,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            return logits, speculative_logits

        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
        cuda_graph["block_tables"][
            : block_tables.shape[0], : block_tables.shape[1]
        ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths

        # Replay the graph
        cuda_graph["graph"].replay()

        # Slice output to the correct shape
        speculative_logits = (
            cuda_graph["speculative_logits"][:bs]
            if cuda_graph["speculative_logits"] is not None
            else None
        )
        logits = cuda_graph["logits"][:bs]
        return logits, speculative_logits

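# --- Illustrative sketch, not part of the diff: the batch-size bucketing used ---
# --- above to look up a captured CUDA graph. The helper name is hypothetical. ---
def padded_batch_size(bs: int) -> int:
    # Mirrors the branches in VlmCausalLM.forward: bs == 3 pads to 4,
    # 4..8 pads to 8, and anything larger rounds up to a multiple of 8.
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8
    return bs

assert [padded_batch_size(b) for b in (1, 3, 5, 9, 17)] == [1, 4, 8, 16, 24]
# Decode then replays self.cuda_graphs[padded_bs] when such a graph was captured;
# prefill (cu_seqlen_prefill is not None) always takes the eager path.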
@@ -13,6 +13,7 @@ from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

@@ -78,13 +79,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        except ImportError:
            pass

        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
        }:  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )

@@ -100,13 +103,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

    async def Prefill(self, request, context):
        start = time.time_ns()
        if (
            self.model.batch_type == IdeficsCausalLMBatch
        ):  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb(
        if self.model.batch_type in {
            IdeficsCausalLMBatch,
            VlmCausalLMBatch,
        }:  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )

@@ -68,7 +68,7 @@ def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
    Forcing us to check for potentially different keys during load when looking
    for specific tensors (making tensor sharing explicit).
    """
    loaded = torch.load(pt_file, map_location="cpu")
    loaded = torch.load(pt_file, map_location="cpu", weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

@@ -122,6 +122,9 @@ def attention(
        out,
        cu_seqlens,
        cu_seqlens,
        None,
        None,
        None,
        max_s,
        max_s,
        0.0,

@@ -18,12 +18,12 @@ except ImportError:
from accelerate import init_empty_weights

from text_generation_server.utils.gptq.quant_linear import QuantLinear

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)
from text_generation_server.utils.log import log_once

if IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex

@@ -42,12 +42,6 @@ except Exception:
    HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
#     V2 = False
#     log_once(
#         logger.warning,
#         "Disabling exllama v2 and using v1 instead because there are issues when sharding",
#     )

if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False

@@ -181,6 +175,8 @@ class EETQLinear(nn.Module):
    ) -> None:
        super().__init__()
        device = weight.device
        if weight.dtype != torch.float16:
            weight = weight.to(dtype=torch.float16)
        weight = torch.t(weight).contiguous().cpu()
        weight, scale = quant_weights(weight, torch.int8, False)

@@ -194,6 +190,48 @@ class EETQLinear(nn.Module):
        return output


def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    device = weight.device
    # weight, scale = quant_weights(weight, torch.int8, False)
    finfo = torch.finfo(qdtype)
    # Calculate the scale as dtype max divided by absmax
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # scale and clamp the tensor to bring it to
    # the representative range of float8 data type
    # (as default cast is unsaturated)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both float8 data and the inverse scale (as float),
    # as both required as inputs to torch._scaled_mm
    qweight = qweight.to(qdtype)
    scale = scale.float().reciprocal()
    return qweight, scale


class Fp8Linear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.dtype = weight.dtype
        self.qweight, self.scale = fp8_quantize(weight)

        self.bias = bias if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qinput, scale = fp8_quantize(input)
        output, _ = torch._scaled_mm(
            qinput,
            self.qweight.t(),
            out_dtype=self.dtype,
            scale_a=scale,
            scale_b=self.scale,
            bias=self.bias,
        )
        return output

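# --- Illustrative sketch, not part of the diff: what fp8_quantize/Fp8Linear do ---
# --- numerically. Assumes a GPU with fp8 support for torch._scaled_mm and the  ---
# --- tuple return used above; shapes, values and the helper name are made up.  ---
def _fp8_roundtrip_demo():
    w = torch.randn(16, 32, dtype=torch.float16, device="cuda")
    x = torch.randn(4, 32, dtype=torch.float16, device="cuda")
    qw, w_inv_scale = fp8_quantize(w)  # float8 weight + inverse scale
    qx, x_inv_scale = fp8_quantize(x)  # float8 activation + inverse scale
    # _scaled_mm multiplies the fp8 operands and rescales back to fp16:
    # y ~= (qx * x_inv_scale) @ (qw * w_inv_scale).T
    y, _ = torch._scaled_mm(
        qx, qw.t(), out_dtype=torch.float16, scale_a=x_inv_scale, scale_b=w_inv_scale
    )
    return y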
class Linear8bitLt(nn.Module):
    def __init__(
        self,

@@ -305,6 +343,8 @@ def get_linear(weight, bias, quantize):
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "fp8":
        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        warn_deprecate_bnb()
        linear = Linear8bitLt(

@@ -400,12 +440,12 @@ class ResBlock(torch.nn.Module):


class MedusaModel(torch.nn.Module):
    def __init__(self, config, weights):
    def __init__(self, config, medusa_config, weights):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [
                MedusaHead(config, prefix=f"{i}", weights=weights)
                for i in range(config["medusa_num_heads"])
                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
                for i in range(medusa_config["medusa_num_heads"])
            ]
        )

@@ -415,12 +455,12 @@ class MedusaModel(torch.nn.Module):


class MedusaHead(torch.nn.Module):
    def __init__(self, config, prefix, weights):
    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [
                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
                for i in range(config["medusa_num_layers"])
                for i in range(medusa_config["medusa_num_layers"])
            ]
        )
        n = len(self.blocks)

@@ -435,7 +475,7 @@ class MedusaHead(torch.nn.Module):
        return x


class SpeculativeHead(nn.Module):
class MedusaHeadV1(nn.Module):
    def __init__(self, lm_head, medusa):
        super().__init__()
        self.lm_head = lm_head

@@ -443,38 +483,156 @@ class SpeculativeHead(nn.Module):

    @staticmethod
    def load(config, prefix: str, weights):
        from pathlib import Path
        from safetensors import safe_open
        import json

        use_medusa = config.use_medusa

        medusa_config = str(Path(use_medusa) / "config.json")
        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

        with open(medusa_config, "r") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

        speculative_logits = self.medusa(input)
        return logits, speculative_logits

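# --- Illustrative sketch, not part of the diff: the on-disk layout MedusaHeadV1 ---
# --- and MedusaHeadV2 load from. Only the two file names and the two keys used  ---
# --- above are taken from the code; the directory name and values are examples. ---
#
#   <config.use_medusa>/
#       config.json                 # read into `medusa_config`
#       medusa_lm_head.safetensors  # routed into `weights` for the head tensors
#
# A minimal config.json consistent with the code above (MedusaHeadV2 additionally
# asserts medusa_num_layers == 1):
#
#   {
#       "medusa_num_heads": 4,
#       "medusa_num_layers": 1
#   }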
class MedusaHeadV2(nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
        from pathlib import Path
        from safetensors import safe_open
        import json

        use_medusa = config.use_medusa

        medusa_config = str(Path(use_medusa) / "config.json")
        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

        with open(medusa_config, "r") as f:
            medusa_config = json.load(f)
        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = medusa_config["medusa_num_heads"]

        assert medusa_config["medusa_num_layers"] == 1
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        # If we have too many tokens, we skip speculative logits
        if x.shape[0] > 128:
            logits = self.lm_head(x)
            return logits, None

        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size

        x_block = x[:, start:stop]

        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Apply all residual medusa heads
        output = x[:, start:stop].unsqueeze(-2) + medusa_res

        # Gather medusa heads
        world_output = [
            torch.empty_like(output) for _ in range(self.process_group.size())
        ]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)

        # Stack x and medusa residual x
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)

        # Compute lm head on x + medusa residual x
        logits = self.lm_head(stacked_x)

        # Finally, split logits from speculative logits
        logits, speculative_logits = torch.split(
            logits, [1, self.n_medusa_heads], dim=-2
        )
        # Squeeze added dimension
        logits = logits.squeeze(-2)

        return logits, speculative_logits

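# --- Illustrative sketch, not part of the diff: the per-rank slicing used in ---
# --- MedusaHeadV2.forward. Hidden size, world size and the helper name below ---
# --- are made up.                                                             ---
def _medusa_v2_block(size: int = 4096, world_size: int = 4, rank: int = 1):
    # Each rank keeps one contiguous block of the hidden dimension; the column-
    # parallel medusa linears produce that block, and all_gather stitches the
    # blocks back together before the (tensor-parallel) lm_head is applied.
    block_size = (size + world_size - 1) // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return start, stop

assert _medusa_v2_block(4096, 4, 1) == (1024, 2048)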
class SpeculativeHead(nn.Module):
    def __init__(self, lm_head, medusa):
        super().__init__()
        self.head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        use_medusa = config.use_medusa
        if use_medusa:
            from pathlib import Path
            from safetensors import safe_open
            import json

            medusa_config = str(Path(use_medusa) / "config.json")
            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

            with open(medusa_config, "r") as f:
                config = json.load(f)
            routing = weights.routing
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    weights.routing[k] = filename

            medusa = MedusaModel(config, weights)
            lm_head = None
            try:
                medusa = MedusaHeadV1.load(config, prefix, weights)
            except:
                medusa = MedusaHeadV2(config, prefix, weights)
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            medusa = None
        return SpeculativeHead(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        speculative_logits = self.medusa(input) if self.medusa is not None else None
        return logits, speculative_logits
        if self.medusa is not None:
            return self.medusa(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None


class TensorParallelHead(SuperLayer):

@@ -6,8 +6,7 @@ from text_generation_server.utils.import_utils import (
)

if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
    from vllm import cache_ops
    from vllm import attention_ops
    from vllm._C import cache_ops, ops

_PARTITION_SIZE = 512


@@ -22,8 +21,11 @@ def reshape_and_cache(
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
    elif IS_XPU_SYSTEM:
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots

@@ -67,6 +69,7 @@ def attention(
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.

    if IS_XPU_SYSTEM:
        query = query.contiguous()
        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(

@@ -83,9 +86,10 @@ def attention(
            None,
        )

    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

    if use_v1:
        attention_ops.paged_attention_v1(
        ops.paged_attention_v1(
            out,
            query,
            key_cache,

@@ -97,6 +101,8 @@ def attention(
            block_size,
            max_s,
            None,
            "auto",
            1.0,
        )
    else:
        # Run PagedAttention V2.

@@ -112,7 +118,7 @@ def attention(
            device=out.device,
        )
        max_logits = torch.empty_like(exp_sums)
        attention_ops.paged_attention_v2(
        ops.paged_attention_v2(
            out,
            exp_sums,
            max_logits,

@@ -127,4 +133,6 @@ def attention(
            block_size,
            max_s,
            None,
            "auto",
            1.0,
        )

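# --- Illustrative sketch, not part of the diff: the V1/V2 dispatch changed above. ---
# --- The helper name and numbers are examples; the partition size of 512 is the  ---
# --- _PARTITION_SIZE defined in this file, and max_num_partitions is assumed to  ---
# --- be derived from max_s with it.                                               ---
def _uses_paged_attention_v1(num_seqs: int, num_heads: int, max_s: int) -> bool:
    partition_size = 512
    max_num_partitions = (max_s + partition_size - 1) // partition_size
    # V2 splits the reduction across partitions; it only pays off for longer
    # contexts with little other parallel work, and above 8192 tokens of
    # context V2 is now always used.
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

assert _uses_paged_attention_v1(num_seqs=1, num_heads=32, max_s=256) is True
assert _uses_paged_attention_v1(num_seqs=1, num_heads=32, max_s=4096) is False
assert _uses_paged_attention_v1(num_seqs=64, num_heads=32, max_s=4096) is True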
@@ -0,0 +1,5 @@
#!/bin/bash

ldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'

text-generation-launcher $@