commit e6d07a6d34
@@ -45,7 +45,7 @@ jobs:
           export dockerfile="Dockerfile"
           export label_extension=""
           export docker_devices=""
-          export runs_on="aws-g6-12xlarge-plus-priv"
+          export runs_on="aws-g6-12xl-plus-priv-cache"
           export platform=""
           ;;
         rocm)
@@ -38,4 +38,4 @@ jobs:
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
      - name: Rust tests.
-       run: nix develop .#test --command cargo test
+       run: nix build .#checks.$(nix eval --impure --raw --expr 'builtins.currentSystem').rust -L
File diff suppressed because it is too large
@@ -258,7 +258,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements_cuda.txt && \
-    pip install ".[bnb, accelerate, marlin, quantize, peft, outlines]" --no-cache-dir && \
+    pip install ".[bnb, accelerate, marlin, moe, quantize, peft, outlines]" --no-cache-dir && \
     pip install nvidia-nccl-cu12==2.22.3

 ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2
@@ -44,9 +44,35 @@ RUN cargo build --profile release-opt

 # Text Generation Inference base image for Intel

-FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu
+FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu

 USER root

+ARG MAMBA_VERSION=23.1.0-1
+ARG PYTHON_VERSION='3.11.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
+# Install mamba
+# translating Docker's TARGETPLATFORM into mamba arches
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
+        *)             MAMBA_ARCH=x86_64  ;; \
+    esac && \
+    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    rm ~/mambaforge.sh
+
+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") exit 1 ;; \
+        *) /opt/conda/bin/conda update -y conda && \
+           /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
+    esac && \
+    /opt/conda/bin/conda clean -ya
+
 # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
 RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
     dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
@@ -56,7 +82,7 @@ RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dea
 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
   | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list

-RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils

 # Text Generation Inference base env
 ENV HF_HOME=/data \
@@ -65,9 +91,7 @@ ENV HF_HOME=/data \


 WORKDIR /usr/src
-RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
-RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
-RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
+RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir

 # Install server
 COPY proto proto
@@ -82,14 +106,12 @@ ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
 ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
 ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
 ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
-ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
-ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
+ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:/opt/conda/lib
+ENV PATH=/opt/conda/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
 ENV CCL_ZE_IPC_EXCHANGE=sockets
 ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
 ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include

-RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch

 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
@@ -133,12 +155,19 @@ RUN chmod +x ~/mambaforge.sh && \
     bash ~/mambaforge.sh -b -p /opt/conda && \
     rm ~/mambaforge.sh

+RUN case ${TARGETPLATFORM} in \
+        "linux/arm64") exit 1 ;; \
+        *) /opt/conda/bin/conda update -y conda && \
+           /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
+    esac && \
+    /opt/conda/bin/conda clean -ya
+
 RUN conda install -c conda-forge gperftools mkl

-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
-RUN pip install triton numa
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp311-cp311-linux_x86_64.whl
+RUN pip install triton py-libnuma

 WORKDIR /usr/src
@@ -151,10 +180,10 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update
 RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .

 ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so
-ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
-ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
-ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
-ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
+ENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
+ENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch
+ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
+ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib
 ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib/"

 # Install server
@@ -364,7 +364,7 @@ impl State {
                     // Add it back to the front
                     tracing::debug!("Over budget: not enough free blocks");
                     self.entries.push_front((id, entry));
-                    break;
+                    continue;
                 }
                 Some(block_allocation) => {
                     tracing::debug!("Allocation: {block_allocation:?}");
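Note on the `break` → `continue` change above: with `break`, the first entry that does not fit in the available blocks ends batch building for everyone behind it, while `continue` only skips that entry. A hedged sketch of the general pattern (illustrative names only, not the actual queue internals):

    // Illustrative sketch: build a batch under a block budget, skipping entries
    // that do not fit instead of aborting the whole batch.
    fn build_batch(candidates: Vec<(u64, usize)>, mut free_blocks: usize) -> Vec<u64> {
        let mut batch = Vec::new();
        let mut deferred = Vec::new();
        for (id, blocks_needed) in candidates {
            if blocks_needed > free_blocks {
                // With `break` here, every later (possibly smaller) entry would
                // also be dropped from this batch; `continue` only defers this one.
                deferred.push((id, blocks_needed));
                continue;
            }
            free_blocks -= blocks_needed;
            batch.push(id);
        }
        // In the real queue the deferred entries go back onto the queue for a later batch.
        let _ = deferred;
        batch
    }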
@@ -436,6 +436,12 @@ impl State {
             batch_entries.insert(id, entry);
         }

+        // Empty batch
+        if batch_requests.is_empty() {
+            tracing::debug!("Filterered out all entries");
+            return None;
+        }
+
         // Final batch size
         let size = batch_requests.len() as u32;
         next_batch_span.record("batch_size", size);
@@ -1,10 +1,22 @@
 use crate::block_allocator::{Allocator, BlockAllocation};
 use slotmap::{DefaultKey, SlotMap};
+use std::hash::{Hash, Hasher};
 use std::{
     collections::{BTreeSet, HashMap},
     sync::Arc,
 };

+fn hash(slice: &[u32]) -> u64 {
+    assert!(!slice.is_empty());
+    if slice.len() == 1 {
+        slice[0] as u64
+    } else {
+        let mut s = std::hash::DefaultHasher::new();
+        slice.hash(&mut s);
+        s.finish()
+    }
+}
+
 pub struct RadixAllocator {
     allocation_id: u64,
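The new `hash` helper changes how trie children are keyed: a child is now looked up by a `u64` hash of one block's worth of token ids (the single-token case is passed through unchanged) rather than by the first token id alone. A small self-contained sketch that exercises the helper, with a hypothetical children map standing in for `TrieNode::children`:

    use std::collections::HashMap;
    use std::hash::{Hash, Hasher};

    // Adapted from the diff above (DefaultHasher path written out explicitly).
    fn hash(slice: &[u32]) -> u64 {
        assert!(!slice.is_empty());
        if slice.len() == 1 {
            slice[0] as u64
        } else {
            let mut s = std::collections::hash_map::DefaultHasher::new();
            slice.hash(&mut s);
            s.finish()
        }
    }

    fn main() {
        let block_size = 4;
        let tokens: Vec<u32> = vec![5, 9, 2, 7, 1, 3, 8, 6];
        // Hypothetical children map, keyed the way the new TrieNode::children is.
        let mut children: HashMap<u64, usize> = HashMap::new();
        for (i, block) in tokens.chunks(block_size).enumerate() {
            children.insert(hash(block), i);
        }
        // A lookup uses the hash of the first block of the remaining key.
        assert_eq!(children.get(&hash(&tokens[..block_size])), Some(&0));
    }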
@@ -44,6 +56,10 @@ impl RadixAllocator {
         // the free list if we cannot allocate enough blocks. This is only
         // temporary, the trie needs to be able to report whether it can
         // allocate the requested amount. Just not implemented yet.
+        tracing::debug!(
+            "Free blocks {} need {n_blocks_needed}",
+            self.free_blocks.len()
+        );
         self.free_blocks.extend(
             self.cache_blocks
                 .evict(n_blocks_needed - self.free_blocks.len()),
@@ -94,6 +110,9 @@ impl Allocator for RadixAllocator {
         match self.alloc_or_reclaim(suffix_blocks as usize) {
             Some(suffix_blocks) => blocks.extend(suffix_blocks),
             None => {
+                tracing::debug!("Cannot allocate {:?}", self.cache_blocks);
+                tracing::debug!("Found {prefix_len} prefix tokens need {suffix_blocks} suffix blocks for {tokens} tokens");
+                tracing::debug!("Block size {}", self.block_size);
                 self.cache_blocks
                     .decref(prefix_node)
                     .expect("Failed to decrement refcount");
@@ -211,7 +230,6 @@ struct RadixAllocation {
 pub enum TrieError {
     InvalidNodeId,
     RefCountUnderflow,
-    BlockTokenCountMismatch,
 }

 pub type NodeId = DefaultKey;
@@ -268,16 +286,19 @@ impl RadixTrie {
     fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
         let node = &self.nodes[node_id];

-        if let Some(&child_id) = node.children.get(&key[0]) {
-            self.update_access_time(child_id);
-            let child = self.nodes.get(child_id).expect("Invalid child identifier");
-            let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
-            assert_eq!(shared_prefix_len % self.block_size, 0);
-            blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
+        if key.len() >= self.block_size {
+            let node_key = hash(&key[..self.block_size]);
+            if let Some(&child_id) = node.children.get(&node_key) {
+                self.update_access_time(child_id);
+                let child = self.nodes.get(child_id).expect("Invalid child identifier");
+                let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);
+                assert_eq!(shared_prefix_len % self.block_size, 0);
+                blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);

-            let key = &key[shared_prefix_len..];
-            if !key.is_empty() {
-                node_id = self.find_(child_id, key, blocks);
+                let key = &key[shared_prefix_len..];
+                if !key.is_empty() {
+                    node_id = self.find_(child_id, key, blocks);
+                }
             }
         }
@@ -344,9 +365,11 @@ impl RadixTrie {
         // evict n_blocks and return `None` if we can't. We are now needlessly
         // evicting prefixes from the cache in such a case.
         let mut evicted = Vec::new();
+        tracing::debug!("Evicting in search of {n_blocks}");

         while let Some((last_access, node_id)) = self.leaves.pop_first() {
-            let blocks_needed = n_blocks - evicted.len();
+            let blocks_needed = n_blocks.saturating_sub(evicted.len());
+            tracing::debug!("Evicting node {node_id:?} ");

             let node = self.nodes.get(node_id).expect("Leave does not exist");
             assert_eq!(
@@ -368,8 +391,11 @@ impl RadixTrie {
                 // the required number of blocks and leave the remaining blocks
                 // untouched.
                 let node = self.nodes.get_mut(node_id).expect("Leave does not exist");
-                node.key.truncate(node.blocks.len() - blocks_needed);
-                evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
+                let truncate_blocks = node.blocks.len() - blocks_needed;
+                let truncate_tokens = truncate_blocks * self.block_size;
+                node.key.truncate(truncate_tokens);
+                evicted.extend(node.blocks.split_off(truncate_blocks));
                 self.leaves.insert((last_access, node_id));
                 break;
             }
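The fix in this hunk is a unit mismatch: `node.key` stores tokens while `node.blocks` stores block ids, so a partial eviction has to truncate the key in token units (`blocks * block_size`), not in block units. A minimal sketch of that relationship, with hypothetical names:

    // Illustrative only: partially evict `blocks_needed` blocks from the tail of a
    // node whose key holds block_size tokens per block.
    fn partial_evict(
        key: &mut Vec<u32>,
        blocks: &mut Vec<u32>,
        blocks_needed: usize,
        block_size: usize,
    ) -> Vec<u32> {
        let truncate_blocks = blocks.len() - blocks_needed;
        let truncate_tokens = truncate_blocks * block_size;
        key.truncate(truncate_tokens); // token units
        blocks.split_off(truncate_blocks) // block units; returns the evicted tail
    }

    fn main() {
        let block_size = 2;
        let mut key = vec![1, 2, 3, 4, 5, 6]; // 3 blocks' worth of tokens
        let mut blocks = vec![10, 11, 12];
        let evicted = partial_evict(&mut key, &mut blocks, 1, block_size);
        assert_eq!(evicted, vec![12]);
        assert_eq!(key, vec![1, 2, 3, 4]); // 2 blocks * 2 tokens remain
        assert_eq!(blocks, vec![10, 11]);
    }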
@@ -400,11 +426,10 @@ impl RadixTrie {
         // the part of the prefix that is already in the trie to detect
         // mismatches.

-        if tokens.len() != blocks.len() * self.block_size {
-            return Err(TrieError::BlockTokenCountMismatch);
-        }
+        assert_eq!(tokens.len(), blocks.len() * self.block_size);

-        if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
+        let node_key = hash(&tokens[..self.block_size]);
+        if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {
             self.update_access_time(child_id);
             let child = self
                 .nodes
@@ -452,14 +477,15 @@ impl RadixTrie {
             .get_mut(node_id)
             .expect("Node to-be split does not exist");
         let mut parent_key = node.key.split_off(prefix_len);
-        let mut parent_blocks = node.blocks.split_off(prefix_len);
+        let prefix_blocks = prefix_len / self.block_size;
+        let mut parent_blocks = node.blocks.split_off(prefix_blocks);

         // Move first part of the prefix to the parent. We swap to avoid
         // an allocation + copy for both splits of the key/blocks.
         std::mem::swap(&mut node.key, &mut parent_key);
         std::mem::swap(&mut node.blocks, &mut parent_blocks);

-        let node_key = node.key[0];
+        let node_key = hash(&node.key[..self.block_size]);

         let grandparent_id = node.parent.expect("Node does not have a parent");
         let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
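This is the same unit issue as in the eviction fix above: `prefix_len` is measured in tokens, so the blocks vector must be split at `prefix_len / block_size` rather than at `prefix_len`. A minimal sketch with hypothetical names:

    // Illustrative only: split a node's key (tokens) and blocks (block ids) at a
    // shared prefix of `prefix_len` tokens, keeping the units consistent.
    fn split_at_prefix(
        key: &mut Vec<u32>,
        blocks: &mut Vec<u32>,
        prefix_len: usize,
        block_size: usize,
    ) -> (Vec<u32>, Vec<u32>) {
        assert_eq!(prefix_len % block_size, 0);
        let suffix_key = key.split_off(prefix_len);
        let suffix_blocks = blocks.split_off(prefix_len / block_size);
        (suffix_key, suffix_blocks)
    }

    fn main() {
        let block_size = 2;
        let mut key = vec![1, 2, 3, 4, 5, 6];
        let mut blocks = vec![10, 11, 12];
        let (suffix_key, suffix_blocks) = split_at_prefix(&mut key, &mut blocks, 2, block_size);
        assert_eq!((key, blocks), (vec![1, 2], vec![10]));
        assert_eq!((suffix_key, suffix_blocks), (vec![3, 4, 5, 6], vec![11, 12]));
    }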
@@ -484,7 +510,7 @@ impl RadixTrie {
     ) -> NodeId {
         let key = key.into();
         let blocks = blocks.into();
-        let first = key[0];
+        let first = hash(&key[..self.block_size]);

         let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
         let child_id = self.nodes.insert(child);
@@ -496,10 +522,10 @@ impl RadixTrie {
     }

     /// Add a node to the parent.
-    fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
+    fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {
         // Unwrap here, passing in an unknown id is a programming error.
         let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
-        if parent.children.insert(first, child_id).is_none() {
+        if parent.children.insert(hash, child_id).is_none() {
             // Only increase reference count if child does not replace another child.
             self.incref(parent_id)
                 .expect("Failed to increase parent refcount");
@@ -517,7 +543,9 @@ impl RadixTrie {
         );
         let parent_id = node.parent.expect("Attempted to remove root node");
         let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
-        parent.children.remove(&node.key[0]);
+
+        let node_key = hash(&node.key[..self.block_size]);
+        parent.children.remove(&node_key);
         self.decref(parent_id)
             .expect("Failed to decrease parent refcount");
         node
@@ -571,7 +599,7 @@ impl RadixTrie {
 #[derive(Debug)]
 struct TrieNode {
     blocks: Vec<u32>,
-    children: HashMap<u32, NodeId>,
+    children: HashMap<u64, NodeId>,
     key: Vec<u32>,
     last_accessed: u64,
     parent: Option<NodeId>,
@@ -16,7 +16,7 @@ path = "src/main.rs"
 [dependencies]
 average = "0.14"
 clap = { version = "4.4.5", features = ["derive", "env"] }
-crossterm = "0.27"
+crossterm = "0.28.1"
 float-ord = "0.3.2"
 serde = {version = "1.0.188", features = ["derive"]}
 serde_json = "1.0"
|
||||||
thiserror = "1.0.48"
|
thiserror = "1.0.48"
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||||
tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
|
ratatui = { version = "0.28.1", default-features = false, features = ["crossterm"] }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
hf-hub = { workspace = true }
|
hf-hub = { workspace = true }
|
||||||
|
|
|
@@ -1,16 +1,15 @@
 /// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
 use crate::generation::{Decode, Message, Prefill};
 use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
-use text_generation_client::ClientError;
-use tokio::sync::mpsc;
-use tui::backend::Backend;
-use tui::layout::{Alignment, Constraint, Direction, Layout};
-use tui::style::{Color, Modifier, Style};
-use tui::text::{Line, Span};
-use tui::widgets::{
+use ratatui::layout::{Alignment, Constraint, Direction, Layout};
+use ratatui::style::{Color, Modifier, Style};
+use ratatui::text::{Line, Span};
+use ratatui::widgets::{
     Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
 };
-use tui::{symbols, Frame};
+use ratatui::{symbols, Frame};
+use text_generation_client::ClientError;
+use tokio::sync::mpsc;

 /// TUI powered App
 pub(crate) struct App {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Render frame
|
/// Render frame
|
||||||
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
|
pub fn render(&mut self, f: &mut Frame) {
|
||||||
let batch_progress =
|
let batch_progress =
|
||||||
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
|
(self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);
|
||||||
let run_progress =
|
let run_progress =
|
||||||
|
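In ratatui 0.28 the `Frame` type is no longer generic over a backend, which is why the render signature above drops its `<B: Backend>` parameter. A rough sketch of the call site this implies (simplified stand-in, not the benchmark's actual main loop):

    use std::io;

    use ratatui::backend::CrosstermBackend;
    use ratatui::Terminal;

    // Simplified stand-in for the benchmark's App; the real type lives in app.rs.
    struct App;
    impl App {
        fn render(&mut self, f: &mut ratatui::Frame) {
            // Layouts split on f.area() instead of the older f.size().
            let _ = f.area();
        }
    }

    fn main() -> io::Result<()> {
        let backend = CrosstermBackend::new(io::stdout());
        let mut terminal = Terminal::new(backend)?;
        let mut app = App;
        // draw() hands the closure a &mut Frame with no backend type parameter.
        terminal.draw(|f| app.render(f))?;
        Ok(())
    }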
@@ -172,7 +171,7 @@ impl App {
             ]
             .as_ref(),
             )
-            .split(f.size());
+            .split(f.area());

         // Top row horizontal layout
         let top = Layout::default()
@@ -239,7 +238,7 @@ impl App {
         f.render_widget(helper, row5[0]);

         // Batch tabs
-        let titles = self
+        let titles: Vec<Line> = self
             .data
             .batch_size
             .iter()
@@ -7,12 +7,12 @@ mod utils;
 use crate::app::App;
 use crate::event::Event;
 use crossterm::ExecutableCommand;
+use ratatui::backend::CrosstermBackend;
+use ratatui::Terminal;
 use std::io;
 use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
-use tui::backend::CrosstermBackend;
-use tui::Terminal;

 /// Run benchmarking app
 #[allow(clippy::too_many_arguments)]
flake.lock (12 lines changed)
@@ -853,11 +853,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1726021481,
-      "narHash": "sha256-4J4E+Fh+77XIYnq2RVtg+ENWXpu6t74P0jKN/f2RQmI=",
+      "lastModified": 1726280639,
+      "narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "1c2c120246c51a644c20ba2a36a33d3bd4860d70",
+      "rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
       "type": "github"
     },
     "original": {
@@ -978,11 +978,11 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1725950569,
-      "narHash": "sha256-nJHA1SvIQbXySpL2ueNbzQOhnkQASa5tOLz/kdW0PWA=",
+      "lastModified": 1726229792,
+      "narHash": "sha256-9xsLmjc9nr7a4PTddKv2DOi82ompTtJNyjO6R67y5tE=",
       "owner": "danieldk",
       "repo": "tgi-nix",
-      "rev": "d40f3c22e9bcc5e16c94d4605cf6a7d74dd07f46",
+      "rev": "1a902f4818e94c3f8d95f6000db17bc3fadd0ce7",
       "type": "github"
     },
     "original": {
flake.nix (22 lines changed)
@@ -69,7 +69,29 @@
           server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
         in
         {
+          checks = {
+            rust = with pkgs; rustPlatform.buildRustPackage {
+              name = "rust-checks";
+              src = ./.;
+              cargoLock = {
+                lockFile = ./Cargo.lock;
+              };
+              buildInputs = [ openssl.dev ];
+              nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ];
+              buildPhase = ''
+                cargo check
+              '';
+              checkPhase = ''
+                cargo fmt -- --check
+                cargo test -j $NIX_BUILD_CORES
+                cargo clippy
+              '';
+              installPhase = "touch $out";
+            };
+          };
+
           formatter = pkgs.nixfmt-rfc-style;

           devShells = with pkgs; rec {
             default = pure;
@@ -342,6 +342,7 @@ def launcher(event_loop):
         max_total_tokens: Optional[int] = None,
         lora_adapters: Optional[List[str]] = None,
         cuda_graphs: Optional[List[int]] = None,
+        attention: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -401,6 +402,8 @@ def launcher(event_loop):

         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
+        if attention is not None:
+            env["ATTENTION"] = attention

         with tempfile.TemporaryFile("w+") as tmp:
             # We'll output stdout/stderr to a temporary file. Using a pipe
@@ -437,6 +440,7 @@ def launcher(event_loop):
         max_total_tokens: Optional[int] = None,
         lora_adapters: Optional[List[str]] = None,
        cuda_graphs: Optional[List[int]] = None,
+        attention: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)

@@ -491,6 +495,8 @@ def launcher(event_loop):
         }
         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
+        if attention is not None:
+            env["ATTENTION"] = attention

         if HF_TOKEN is not None:
             env["HF_TOKEN"] = HF_TOKEN
@@ -522,6 +528,7 @@ def launcher(event_loop):
             devices=devices,
             volumes=volumes,
             ports={"80/tcp": port},
+            healthcheck={"timeout": int(10 * 1e9)},
             shm_size="1G",
         )
File diff suppressed because it is too large
@ -0,0 +1,114 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1824,
|
||||||
|
"logprob": -6.1445312,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -1.4648438,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21135,
|
||||||
|
"logprob": -13.6875,
|
||||||
|
"text": "gradient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -1.6005859,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.39526367,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.640625,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.18774414,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 20910,
|
||||||
|
"logprob": -0.96484375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Grad"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 722,
|
||||||
|
"logprob": -0.003168106,
|
||||||
|
"special": false,
|
||||||
|
"text": "ient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -0.16540527,
|
||||||
|
"special": false,
|
||||||
|
"text": " descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.08886719,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 396,
|
||||||
|
"logprob": -0.75878906,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18586,
|
||||||
|
"logprob": -0.5703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " optimization"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9464,
|
||||||
|
"logprob": -0.11242676,
|
||||||
|
"special": false,
|
||||||
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1307,
|
||||||
|
"logprob": -0.7939453,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -0.17102051,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26518,
|
||||||
|
"logprob": -0.34326172,
|
||||||
|
"special": false,
|
||||||
|
"text": " minimize"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -17.234375,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -7.4375,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.8046875,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.33032227,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 1313,
|
||||||
|
"logprob": -2.3613281,
|
||||||
|
"special": false,
|
||||||
|
"text": "It"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3969,
|
||||||
|
"logprob": -0.7285156,
|
||||||
|
"special": false,
|
||||||
|
"text": " seems"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -1.3466797,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 528,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " me"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28725,
|
||||||
|
"logprob": -1.6757812,
|
||||||
|
"special": false,
|
||||||
|
"text": ","
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 369,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 513,
|
||||||
|
"logprob": -1.1269531,
|
||||||
|
"special": false,
|
||||||
|
"text": " if"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 368,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " you"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28742,
|
||||||
|
"logprob": -2.4921875,
|
||||||
|
"special": false,
|
||||||
|
"text": "'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 267,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "re"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "What is gradient descent?\n\nIt seems to me, that if you're"
|
||||||
|
}
|
|
@ -0,0 +1,458 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1824,
|
||||||
|
"logprob": -6.1445312,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -1.4648438,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21135,
|
||||||
|
"logprob": -13.6875,
|
||||||
|
"text": "gradient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -1.6005859,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.39526367,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.640625,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.18774414,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 20910,
|
||||||
|
"logprob": -0.96484375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Grad"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 722,
|
||||||
|
"logprob": -0.003168106,
|
||||||
|
"special": false,
|
||||||
|
"text": "ient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -0.16369629,
|
||||||
|
"special": false,
|
||||||
|
"text": " descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.0881958,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 396,
|
||||||
|
"logprob": -0.76708984,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18586,
|
||||||
|
"logprob": -0.57373047,
|
||||||
|
"special": false,
|
||||||
|
"text": " optimization"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9464,
|
||||||
|
"logprob": -0.11291504,
|
||||||
|
"special": false,
|
||||||
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1307,
|
||||||
|
"logprob": -0.79589844,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -0.1694336,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26518,
|
||||||
|
"logprob": -0.34350586,
|
||||||
|
"special": false,
|
||||||
|
"text": " minimize"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1824,
|
||||||
|
"logprob": -6.1445312,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -1.4677734,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21135,
|
||||||
|
"logprob": -13.6875,
|
||||||
|
"text": "gradient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -1.6015625,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.39453125,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.6435547,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.18713379,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 20910,
|
||||||
|
"logprob": -0.9628906,
|
||||||
|
"special": false,
|
||||||
|
"text": "Grad"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 722,
|
||||||
|
"logprob": -0.0032176971,
|
||||||
|
"special": false,
|
||||||
|
"text": "ient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -0.16540527,
|
||||||
|
"special": false,
|
||||||
|
"text": " descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.08898926,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 396,
|
||||||
|
"logprob": -0.765625,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18586,
|
||||||
|
"logprob": -0.5708008,
|
||||||
|
"special": false,
|
||||||
|
"text": " optimization"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9464,
|
||||||
|
"logprob": -0.11401367,
|
||||||
|
"special": false,
|
||||||
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1307,
|
||||||
|
"logprob": -0.7963867,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -0.17028809,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26518,
|
||||||
|
"logprob": -0.34326172,
|
||||||
|
"special": false,
|
||||||
|
"text": " minimize"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1824,
|
||||||
|
"logprob": -6.140625,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -1.4658203,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21135,
|
||||||
|
"logprob": -13.6796875,
|
||||||
|
"text": "gradient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -1.5898438,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.3955078,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.64501953,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.18493652,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 20910,
|
||||||
|
"logprob": -0.9580078,
|
||||||
|
"special": false,
|
||||||
|
"text": "Grad"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 722,
|
||||||
|
"logprob": -0.0032176971,
|
||||||
|
"special": false,
|
||||||
|
"text": "ient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -0.16552734,
|
||||||
|
"special": false,
|
||||||
|
"text": " descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.08874512,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 396,
|
||||||
|
"logprob": -0.75878906,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18586,
|
||||||
|
"logprob": -0.5703125,
|
||||||
|
"special": false,
|
||||||
|
"text": " optimization"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9464,
|
||||||
|
"logprob": -0.11236572,
|
||||||
|
"special": false,
|
||||||
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1307,
|
||||||
|
"logprob": -0.79541016,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -0.17102051,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26518,
|
||||||
|
"logprob": -0.34326172,
|
||||||
|
"special": false,
|
||||||
|
"text": " minimize"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1824,
|
||||||
|
"logprob": -6.1328125,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -1.4658203,
|
||||||
|
"text": "is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21135,
|
||||||
|
"logprob": -13.6796875,
|
||||||
|
"text": "gradient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -1.5947266,
|
||||||
|
"text": "descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28804,
|
||||||
|
"logprob": -0.39648438,
|
||||||
|
"text": "?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.6464844,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.18688965,
|
||||||
|
"text": "\n"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 20910,
|
||||||
|
"logprob": -0.9609375,
|
||||||
|
"special": false,
|
||||||
|
"text": "Grad"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 722,
|
||||||
|
"logprob": -0.003168106,
|
||||||
|
"special": false,
|
||||||
|
"text": "ient"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24871,
|
||||||
|
"logprob": -0.16601562,
|
||||||
|
"special": false,
|
||||||
|
"text": " descent"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 349,
|
||||||
|
"logprob": -0.088134766,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 396,
|
||||||
|
"logprob": -0.7597656,
|
||||||
|
"special": false,
|
||||||
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 18586,
|
||||||
|
"logprob": -0.5708008,
|
||||||
|
"special": false,
|
||||||
|
"text": " optimization"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 9464,
|
||||||
|
"logprob": -0.11291504,
|
||||||
|
"special": false,
|
||||||
|
"text": " algorithm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1307,
|
||||||
|
"logprob": -0.7944336,
|
||||||
|
"special": false,
|
||||||
|
"text": " used"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 298,
|
||||||
|
"logprob": -0.17102051,
|
||||||
|
"special": false,
|
||||||
|
"text": " to"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 26518,
|
||||||
|
"logprob": -0.34399414,
|
||||||
|
"special": false,
|
||||||
|
"text": " minimize"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Gradient descent is an optimization algorithm used to minimize"
|
||||||
|
}
|
||||||
|
]
|
File diff suppressed because one or more lines are too long
|
@@ -0,0 +1,75 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_mixtral_handle(launcher):
+    with launcher("mistralai/Mixtral-8x7B-v0.1", num_shard=8) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_mixtral(flash_mixtral_handle):
+    await flash_mixtral_handle.health(300)
+    return flash_mixtral_handle.client
+
+
+@pytest.mark.skip(reason="requires > 4 shards")
+@pytest.mark.asyncio
+async def test_flash_mixtral(flash_mixtral, response_snapshot):
+    response = await flash_mixtral.generate(
+        "What is gradient descent?\n\n", max_new_tokens=10, decoder_input_details=True
+    )
+
+    assert response.details.generated_tokens == 10
+    assert (
+        response.generated_text
+        == "Gradient descent is an optimization algorithm used to minimize"
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.skip(reason="requires > 4 shards")
+@pytest.mark.asyncio
+async def test_flash_mixtral_all_params(flash_mixtral, response_snapshot):
+    response = await flash_mixtral.generate(
+        "What is gradient descent?\n\n",
+        max_new_tokens=10,
+        repetition_penalty=1.2,
+        return_full_text=True,
+        stop_sequences=["test"],
+        temperature=0.5,
+        top_p=0.9,
+        top_k=10,
+        truncate=5,
+        typical_p=0.9,
+        watermark=True,
+        decoder_input_details=True,
+        seed=0,
+    )
+
+    assert response.details.generated_tokens == 10
+    assert (
+        response.generated_text
+        == "What is gradient descent?\n\nIt seems to me, that if you're"
+    )
+    assert response == response_snapshot
+
+
+@pytest.mark.skip(reason="requires > 4 shards")
+@pytest.mark.asyncio
+async def test_flash_mixtral_load(flash_mixtral, generate_load, response_snapshot):
+    responses = await generate_load(
+        flash_mixtral, "What is gradient descent?\n\n", max_new_tokens=10, n=4
+    )
+
+    assert len(responses) == 4
+    assert responses[0].details.generated_tokens == 10
+    assert (
+        responses[0].generated_text
+        == "Gradient descent is an optimization algorithm used to minimize"
+    )
+    assert all(
+        [r.generated_text == responses[0].generated_text for r in responses]
+    ), f"{[r.generated_text for r in responses]}"
+
+    assert responses == response_snapshot
@@ -21,6 +21,7 @@
   loguru,
   mamba-ssm,
   marlin-kernels,
+  moe-kernels,
   opentelemetry-api,
   opentelemetry-exporter-otlp,
   opentelemetry-instrumentation-grpc,
@@ -88,6 +89,7 @@ buildPythonPackage {
     loguru
     mamba-ssm
    marlin-kernels
+    moe-kernels
    opentelemetry-api
    opentelemetry-exporter-otlp
    opentelemetry-instrumentation-grpc
@@ -2136,9 +2136,12 @@ async fn start(
         .unwrap();
     // .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
     // .unwrap();
-    let prom_handle = builder
-        .install_recorder()
-        .expect("failed to install metrics recorder");
+    // See: https://github.com/metrics-rs/metrics/issues/467#issuecomment-2022755151
+    let (recorder, _) = builder
+        .build()
+        .expect("failed to build prometheus recorder");
+    let prom_handle = recorder.handle();
+    metrics::set_global_recorder(recorder).expect("Failed to set global recorder");

     // Metrics descriptions
     metrics::describe_counter!("tgi_request_success", "Number of successful requests");
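The change above builds the Prometheus recorder explicitly instead of calling `install_recorder()`, keeping a handle for the `/metrics` route while registering the recorder globally. A hedged sketch of the same pattern in isolation (assumes the `metrics` and `metrics-exporter-prometheus` crates; the exporter future that `build()` also returns is discarded here, as the `_` in the diff does, and the router's bucket configuration is omitted):

    use metrics_exporter_prometheus::PrometheusBuilder;

    // Sketch only: error handling and custom buckets from the router are elided.
    fn init_metrics() -> metrics_exporter_prometheus::PrometheusHandle {
        let builder = PrometheusBuilder::new();
        // build() hands back the recorder plus an exporter task instead of
        // installing a global recorder for us.
        let (recorder, _exporter) = builder
            .build()
            .expect("failed to build prometheus recorder");
        let handle = recorder.handle();
        metrics::set_global_recorder(recorder).expect("Failed to set global recorder");
        handle
    }

    fn main() {
        let handle = init_metrics();
        metrics::counter!("tgi_request_success").increment(1);
        // The handle is what a /metrics endpoint would render.
        println!("{}", handle.render());
    }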
@@ -1242,6 +1242,82 @@ files = [
     {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
 ]

+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:d268d818932ddcbca9bc71021dc63b008aae832827a7c0484cf206bd59cfc9ab"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:614bbc3f41b707b0c40372f0bb00e218ad0842d306f90bef28ce8e98e7fcb7cb"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:c2f48ed541353be03157d4015270dff797f7b7b8a664babdcbdf7414867d5abd"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+
+[[package]]
+name = "moe-kernels"
+version = "0.2.2"
+description = "MoE kernels"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d5f0339b73426c422872f7ff060433df6cd8e881451baf85ee7454e0e905f9d8"},
+]
+
+[package.dependencies]
+nvidia-ml-py = "*"
+torch = "*"
+triton = "*"
+
+[package.source]
+type = "url"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -1600,6 +1676,17 @@ files = [
 [package.dependencies]
 nvidia-nvjitlink-cu12 = "*"
 
+[[package]]
+name = "nvidia-ml-py"
+version = "12.560.30"
+description = "Python Bindings for the NVIDIA Management Library"
+optional = true
+python-versions = "*"
+files = [
+    {file = "nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97"},
+    {file = "nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73"},
+]
+
 [[package]]
 name = "nvidia-nccl-cu12"
 version = "2.20.5"
@@ -3638,6 +3725,7 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
 marlin = ["marlin-kernels", "marlin-kernels", "marlin-kernels", "marlin-kernels"]
+moe = ["moe-kernels", "moe-kernels", "moe-kernels", "moe-kernels"]
 outlines = ["outlines"]
 peft = ["peft"]
 quantize = ["accelerate", "datasets", "texttable"]
@@ -46,6 +46,12 @@ marlin-kernels = [
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
     { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
 ]
+moe-kernels = [
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+]
 rich = "^13.7.1"
 
 [tool.poetry.extras]

@@ -53,6 +59,7 @@ torch = ["torch"]
 accelerate = ["accelerate"]
 bnb = ["bitsandbytes"]
 marlin = ["marlin-kernels"]
+moe = ["moe-kernels"]
 peft = ["peft"]
 quantize = ["texttable", "datasets", "accelerate"]
 outlines = ["outlines"]
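Note: the new "moe" extra only resolves to the moe-kernels wheels listed above (CUDA 12.3 / torch 2.4 builds for CPython 3.9-3.12), so server code has to treat it as optional. A minimal sketch of how such an optional dependency is commonly guarded; the HAS_MOE_KERNELS flag is illustrative and not part of this change:

try:
    # Provided by the moe-kernels wheels declared above.
    from moe_kernels.fused_moe import fused_moe
    HAS_MOE_KERNELS = True
except ImportError:
    # Extra not installed (e.g. CPU-only or unsupported Python): fall back to non-fused MoE paths.
    HAS_MOE_KERNELS = False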
@@ -22,9 +22,9 @@ def attention(
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
-        q,
-        key_cache,
-        value_cache,
+        q.contiguous() if q.device.type == "xpu" else q,
+        key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
+        value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
         out,
         seqlen.cu_seqlen_q,
         seqlen.cu_seqlen_q,
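The XPU path now forces dense layouts before handing tensors to the IPEX varlen attention kernel. A small sketch of the same guard factored into a helper (the helper name is illustrative, not part of the commit):

import torch

def maybe_contiguous(t: torch.Tensor) -> torch.Tensor:
    # Only copy when the tensor lives on an XPU device; on CUDA/CPU it is
    # passed through unchanged, matching the conditionals above.
    return t.contiguous() if t.device.type == "xpu" else t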
@@ -0,0 +1,76 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
+from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)
+
+
+class SparseMoELayer(nn.Module):
+    """
+    Layer for MoE that uses fused kernels to only apply the active experts
+    for each token (rather than applying all experts and selecting the
+    outputs of active experts).
+    """
+
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        if (
+            isinstance(weights.loader, DefaultWeightsLoader)
+            and isinstance(weights.loader.weight_class, UnquantizedWeight)
+        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = UnquantizedSparseMoELayer
+        # Once we wire up GPTQ-Marlin MoE:
+        # elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
+        #     cls = GPTQMarlinSparseMoELayer
+        else:
+            raise ValueError(
+                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
+            )
+
+        self.moe = cls(
+            n_expert_group=n_expert_group,
+            n_experts=n_experts,
+            prefix=prefix,
+            renormalize=renormalize,
+            topk=topk,
+            topk_group=topk_group,
+            weights=weights,
+            gate_proj_name=gate_proj_name,
+            up_proj_name=up_proj_name,
+            down_proj_name=down_proj_name,
+        )
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return self.moe(x, gating_output=gating_output)
+
+    @staticmethod
+    def is_supported(weights: Weights) -> bool:
+        return (
+            (
+                isinstance(weights.loader, DefaultWeightsLoader)
+                and isinstance(weights.loader.weight_class, UnquantizedWeight)
+            )
+            or isinstance(weights.loader, HybridFP8UnquantLoader)
+            # Once we wire up GPTQ-Marlin MoE:
+            # or isinstance(weights.loader, GPTQMarlinWeightsLoader)
+        )
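For reference, a minimal sketch of how the new layer is constructed and called, mirroring the Mixtral call site later in this diff. The `weights` object, the input `x` and the gate output `router_logits` are assumed to come from the usual loader and layer code and are not defined here:

# Hypothetical construction, following the Mixtral wiring below.
moe = SparseMoELayer(
    n_expert_group=None,   # grouped top-k is only used by DeepSeek-V2
    n_experts=8,
    prefix="model.layers.0.block_sparse_moe.experts",
    renormalize=True,
    topk=2,
    topk_group=None,
    weights=weights,
    gate_proj_name="w1",
    up_proj_name="w3",
    down_proj_name="w2",
)
# x: [num_tokens, hidden], router_logits: [num_tokens, n_experts]
out = moe(x, gating_output=router_logits)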
@@ -0,0 +1,125 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import UnquantizedWeight, Weights
+
+if SYSTEM != "ipex":
+    from moe_kernels.fused_moe import fused_moe
+
+
+class UnquantizedSparseMoELayer(nn.Module):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.topk = topk
+        self.topk_group = topk_group
+        self.renormalize = renormalize
+
+        self.gate_up_proj = _load_expert_multi_weights_col(
+            prefix=prefix,
+            n_experts=n_experts,
+            gate_proj_name=gate_proj_name,
+            up_proj_name=up_proj_name,
+            weights=weights,
+        )
+
+        self.down_proj = _load_expert_weights_row(
+            prefix=prefix,
+            n_experts=n_experts,
+            name=down_proj_name,
+            weights=weights,
+        )
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return fused_moe(
+            x,
+            w1=self.gate_up_proj,
+            w2=self.down_proj,
+            gating_output=gating_output,
+            topk=self.topk,
+            renormalize=self.renormalize,
+            inplace=True,
+            use_grouped_topk=self.n_expert_group is not None,
+            num_expert_group=self.n_expert_group,
+            topk_group=self.topk_group,
+        )
+
+
+def _load_expert_multi_weights_col(
+    *,
+    prefix: str,
+    n_experts: int,
+    gate_proj_name: str,
+    up_proj_name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
+
+    return all_weight
+
+
+def _load_expert_weights_row(
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
+
+        assert isinstance(weight, UnquantizedWeight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=weight.weight.dtype,
+                device=weight.weight.device,
+            )
+
+        all_weight[i] = weight.weight
+
+    assert all_weight is not None
+
+    return all_weight
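The two loader helpers above stack per-expert matrices into a single tensor indexed by expert id, so the fused kernel can select experts without a Python-level loop. A toy, CPU-only illustration of that layout, assuming the vLLM-style w1/w2 shapes and using random data in place of real shards:

import torch

n_experts, hidden, inter = 4, 16, 32
# gate_proj and up_proj are concatenated along the output dimension,
# giving one [n_experts, 2 * inter, hidden] tensor (the "w1" argument of fused_moe).
gate_up = torch.empty((n_experts, 2 * inter, hidden))
# down_proj is stacked into [n_experts, hidden, inter] (the "w2" argument).
down = torch.empty((n_experts, hidden, inter))
for i in range(n_experts):
    gate_up[i] = torch.randn(2 * inter, hidden)  # stand-in for get_multi_weights_col(...)
    down[i] = torch.randn(hidden, inter)         # stand-in for get_weights_row(...)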
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 from text_generation_server.models.globals import PAGED_KV
+from moe_kernels.fused_moe import grouped_topk
 import torch
 import torch.distributed
 from text_generation_server.layers import (
@@ -33,6 +34,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
 )
 from text_generation_server.layers.layernorm import FastRMSNorm
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weights
@@ -153,44 +155,6 @@ class DeepseekV2Config(PretrainedConfig):
         )
 
 
-def _load_experts(config, prefix: str, mat: str, weights: Weights):
-    if config.quantize is not None:
-        raise NotImplementedError(
-            "Deepseek V2 does not support weight quantization yet."
-        )
-
-    assert mat in ["gate_proj", "up_proj", "down_proj"]
-
-    world_size = weights.process_group.size()
-    rank = weights.process_group.rank()
-
-    assert (
-        config.moe_intermediate_size % world_size == 0
-    ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards"
-
-    block_size = config.moe_intermediate_size // world_size
-    start = rank * block_size
-    stop = (rank + 1) * block_size
-
-    tensor = torch.empty(
-        (config.n_routed_experts * block_size, config.hidden_size),
-        dtype=weights.dtype,
-        device=weights.device,
-    )
-
-    for i in range(config.n_routed_experts):
-        slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
-
-        if mat == "down_proj":
-            expert_slice = slice_[:, start:stop].t().contiguous()
-        else:
-            expert_slice = slice_[start:stop]
-        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
-            dtype=weights.dtype
-        ).to(device=weights.device)
-    return tensor
-
-
 class DeepseekV2Attention(torch.nn.Module):
     def __init__(
         self,
@@ -454,33 +418,21 @@ class BlockSparseMoE(nn.Module):
         self.moe_intermediate_size = (
             config.moe_intermediate_size // weights.process_group.size()
         )
-        self.n_routed_experts = config.n_routed_experts
-        self.n_expert_group = config.n_group
-        self.topk_group = config.topk_group
-        self.top_k = config.num_experts_per_tok
-        self.norm_topk_prob = config.norm_topk_prob
         self.routed_scaling_factor = config.routed_scaling_factor
 
-        gate_proj = _load_experts(
-            config, f"{prefix}.experts", "gate_proj", weights
-        ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
-
-        up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view(
-            self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim
-        )
-
-        self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1)
-
-        self.down_proj = (
-            _load_experts(config, f"{prefix}.experts", "down_proj", weights)
-            .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
-        )
-
         # Gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
+        self.moe_layer = SparseMoELayer(
+            prefix=f"{prefix}.experts",
+            n_experts=config.n_routed_experts,
+            n_expert_group=config.n_group,
+            renormalize=config.norm_topk_prob,
+            topk=config.num_experts_per_tok,
+            topk_group=config.topk_group,
+            weights=weights,
+        )
+
         if config.n_shared_experts is not None:
             self.shared_experts = DeepseekV2MLP(
                 prefix=f"{prefix}.shared_experts",
@@ -501,25 +453,8 @@ class BlockSparseMoE(nn.Module):
         shared_output = None
 
         router_logits = self.gate(x)
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            self.top_k,
-            renormalize=self.norm_topk_prob,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
-        out = (
-            fused_experts(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                topk_weights,
-                topk_ids,
-                inplace=True,
-            )
-            * self.routed_scaling_factor
-        )
+        out = self.moe_layer(x, gating_output=router_logits)
 
         if shared_output is not None:
             out = out + shared_output
@@ -637,7 +572,9 @@ class DeepseekV2Layer(nn.Module):
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
         ):
-            moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+            moe_cls = (
+                BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
+            )
             self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
         else:
             self.mlp = DeepseekV2MLP(
@@ -801,183 +738,3 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
             hidden_states = hidden_states[lm_head_indices]
         logits, speculative_logits = self.lm_head(hidden_states)
         return logits, speculative_logits
-
-
-# Functions below are from vLLM:
-#
-# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397
-#
-# Remove after we have synced our version with upstream.
-
-
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-def get_default_config(
-    M: int,
-    E: int,
-    N: int,
-    K: int,
-    topk: int,
-    dtype: Optional[str],
-) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    if M <= E:
-        config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
-        }
-    return config
-
-
-def fused_experts(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    inplace: bool = False,
-    override_config: Optional[Dict[str, Any]] = None,
-    use_fp8: bool = False,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-):
-    # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
-
-    import triton.language as tl
-    from vllm import _custom_ops as ops
-    from vllm.model_executor.layers.fused_moe.fused_moe import (
-        get_moe_configs,
-        invoke_fused_moe_kernel,
-        moe_align_block_size,
-    )
-
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
-
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
-            )
-
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache3 = torch.empty(
-        (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
-    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
-
-    invoke_fused_moe_kernel(
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    if inplace:
-        return torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=hidden_states,
-        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
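The vendored grouped_topk above (now supplied by moe_kernels.fused_moe) restricts expert selection to the best-scoring expert groups before taking the final top-k. A small CPU-only illustration of that behaviour, with made-up sizes, following the removed implementation line for line:

import torch

num_tokens, n_experts, n_group, topk_group, topk = 2, 8, 4, 2, 2
gating_output = torch.randn(num_tokens, n_experts)

scores = torch.softmax(gating_output, dim=-1)
# Best expert score per group, then keep only the top `topk_group` groups.
group_scores = scores.view(num_tokens, n_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(num_tokens, n_group, n_experts // n_group)
    .reshape(num_tokens, -1)
)
# Experts outside the selected groups can no longer win the final top-k.
masked = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weights, topk_ids = torch.topk(masked, k=topk, dim=-1, sorted=False)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize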
@@ -24,10 +24,7 @@ import torch.distributed
 
 
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
-
-if SYSTEM != "ipex":
-    from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
@@ -46,6 +43,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.moe import SparseMoELayer
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
@@ -320,40 +318,21 @@ def round_up(x: torch.Tensor, value: int):
 class BlockSparseMoE(nn.Module):
     def __init__(self, prefix, config: MixtralConfig, weights):
         super().__init__()
-        self.hidden_dim = config.hidden_size
-        self.ffn_dim = config.intermediate_size // weights.process_group.size()
-        self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-
-        act = config.hidden_act
-        if "gelu" in act:
-            self.act = lambda x: torch.nn.functional.gelu(
-                x,
-                approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
-                ),
-            )
-        elif "silu" in act:
-            self.act = torch.nn.functional.silu
-        else:
-            self.act = ACT2FN[act]
-
         # gating
         self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
 
-        # merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
-        w1 = _load_experts(config, f"{prefix}.experts", "w1", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        w3 = _load_experts(config, f"{prefix}.experts", "w3", weights).view(
-            self.num_experts, self.ffn_dim, self.hidden_dim
-        )
-        self.w13 = torch.cat([w1, w3], dim=1)
-        self.w2 = (
-            _load_experts(config, f"{prefix}.experts", "w2", weights)
-            .view(self.num_experts, self.ffn_dim, self.hidden_dim)
-            .transpose(1, 2)
-            .contiguous()
-        )
+        self.moe = SparseMoELayer(
+            n_expert_group=None,
+            n_experts=config.num_local_experts,
+            prefix=f"{prefix}.experts",
+            renormalize=True,
+            topk=config.num_experts_per_tok,
+            topk_group=None,
+            weights=weights,
+            gate_proj_name="w1",
+            up_proj_name="w3",
+            down_proj_name="w2",
+        )
 
         self.process_group = weights.process_group
@@ -361,15 +340,7 @@ class BlockSparseMoE(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(x)
-        out = fused_moe(
-            x,
-            self.w13,
-            self.w2,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-        )
+        out = self.moe(x, gating_output=router_logits)
 
         # Reduce sum
         if self.process_group.size() > 1:
@@ -476,7 +447,7 @@ class MixtralLayer(nn.Module):
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
 
-        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
+        moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
         self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
 
         self.input_layernorm = FastRMSNorm.load(
@@ -82,7 +82,7 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
     import numa
     import psutil
 
-    nodes = numa.get_max_node() + 1
+    nodes = numa.info.get_max_node() + 1
     rank_per_node = math.ceil(world_size / nodes)
     num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)
     node_id = int(rank_id / rank_per_node)
@@ -91,18 +91,22 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
         num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)
     else:
         num_cpus_per_rank = int(os.getenv("OMP_NUM_THREADS"))
-    if len(numa.get_membind()) == nodes:
-        numa.set_membind([node_id])
+    if len(numa.memory.get_membind_nodes()) == nodes:
+        numa.memory.set_membind_nodes((node_id))
     torch.set_num_threads(num_cpus_per_rank)
-    if len(numa.get_affinity(0)) == psutil.cpu_count(logical=True):
+    if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):
         cpu_start = num_cpus_per_rank * rank_offset_per_node
-        numa.set_affinity(
+        numa.schedule.run_on_cpus(
             0,
-            list(numa.node_to_cpus(node_id))[
-                cpu_start : cpu_start + num_cpus_per_rank
-            ],
+            *(
+                numa.info.node_to_cpus(node_id)[
+                    cpu_start : cpu_start + num_cpus_per_rank
+                ]
+            ),
         )
-    logger.info(f"affinity={numa.get_affinity(0)}, membind = {numa.get_membind()}")
+    logger.info(
+        f"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}"
+    )
 
 
 @dataclass
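init_cpu_threads_env now goes through the namespaced py-libnuma API (numa.info, numa.memory, numa.schedule) instead of the older flat functions. A minimal sketch for a single rank pinned to node 0, using only the calls that appear above; the first argument of run_on_cpus mirrors the 0 used above to mean the calling process:

import numa
import torch

node_id = 0
numa.memory.set_membind_nodes(node_id)        # bind memory allocations to one NUMA node
cpus = numa.info.node_to_cpus(node_id)        # CPUs belonging to that node
numa.schedule.run_on_cpus(0, *cpus)           # pin the current process to those CPUs
torch.set_num_threads(len(cpus))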
@@ -77,12 +77,12 @@ def load_and_merge_adapters(
 ) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
 
     if len(adapter_parameters.adapter_info) == 1:
-        adapter_info = next(iter(adapter_parameters.adapter_info))
+        adapter = next(iter(adapter_parameters.adapter_info))
         return load_module_map(
             model_id,
-            adapter_info.revision,
-            adapter_info.id,
-            adapter_info.path,
+            adapter.revision,
+            adapter.id,
+            adapter.path,
             weight_names,
             trust_remote_code,
         )
@@ -90,7 +90,6 @@ def load_and_merge_adapters(
     adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)
     return _load_and_merge(
         model_id,
-        adapter_params.revision,
         adapter_params,
         weight_names,
         trust_remote_code,
@@ -109,7 +108,6 @@ class AdapterParametersContainer:
 @lru_cache(maxsize=32)
 def _load_and_merge(
     model_id: str,
-    revision: str,
     adapter_params: AdapterParametersContainer,
     weight_names: Tuple[str],
     trust_remote_code: bool = False,
@@ -126,6 +124,7 @@ def _load_and_merge(
         module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
             load_module_map(
                 model_id,
+                adapter.revision,
                 adapter.id,
                 adapter.path,
                 weight_names,
@@ -120,7 +120,6 @@ class DefaultWeightsLoader(WeightsLoader):
         prefix: str,
         block_sizes: Union[int, List[int]],
     ):
-
         return self.weight_class(
             weights.get_packed_sharded(
                 f"{prefix}.weight", dim=0, block_sizes=block_sizes
@@ -405,6 +404,10 @@ class Weights:
         finally:
             self.weights_loader = old_loader
 
+    @property
+    def loader(self):
+        return self.weights_loader
+
 
 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
     """
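The new read-only loader property is what lets SparseMoELayer.is_supported(...) inspect the active weights loader without reaching into the private attribute. A short sketch of that check, using the loader classes referenced earlier in this diff; the helper name is illustrative:

from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight

def supports_fused_moe(weights) -> bool:
    # Mirrors part of SparseMoELayer.is_supported: unquantized weights can feed
    # the fused MoE kernels directly (the FP8-hybrid loader is also accepted there).
    loader = weights.loader
    return isinstance(loader, DefaultWeightsLoader) and isinstance(
        loader.weight_class, UnquantizedWeight
    )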