Add `impureWithCuda` dev shell (#2677)

* Add `impureWithCuda` dev shell

This shell is handy when developing kernels jointly with TGI: it adds
nvcc and a set of commonly used CUDA libraries to the environment.

We keep these out of the normal impure shell so that the default development
environment stays as clean as possible (avoiding accidental dependencies, etc.).

* Add cuDNN
This commit is contained in:
Daniël de Kok 2024-10-22 11:02:55 +02:00 committed by GitHub
parent 058d3061f7
commit 9c9ef37c56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 3 deletions

View File

@ -137,6 +137,11 @@
impure = callPackage ./nix/impure-shell.nix { inherit server; }; impure = callPackage ./nix/impure-shell.nix { inherit server; };
# Variant of the `impure` dev shell built from the same ./nix/impure-shell.nix,
# but with `withCuda = true`, which makes that file add its CUDA build
# dependencies (cmake, ninja, nvcc and CUDA libraries) to the shell.
impureWithCuda = callPackage ./nix/impure-shell.nix {
inherit server;
withCuda = true;
};
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix { impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; }; server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
}; };

View File

@ -1,7 +1,12 @@
{ {
lib,
mkShell, mkShell,
black, black,
cmake,
isort, isort,
ninja,
which,
cudaPackages,
openssl, openssl,
pkg-config, pkg-config,
protobuf, protobuf,
@ -11,14 +16,17 @@
ruff, ruff,
rust-bin, rust-bin,
server, server,
# Enable dependencies for building CUDA packages. Useful for e.g.
# developing marlin/moe-kernels in-place.
withCuda ? false,
}: }:
mkShell { mkShell {
buildInputs = nativeBuildInputs =
[ [
black black
isort isort
openssl.dev
pkg-config pkg-config
(rust-bin.stable.latest.default.override { (rust-bin.stable.latest.default.override {
extensions = [ extensions = [
@ -31,6 +39,19 @@ mkShell {
redocly redocly
ruff ruff
] ]
++ (lib.optionals withCuda [
cmake
ninja
which
# For most Torch-based extensions, setting CUDA_HOME is enough, but
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
cudaPackages.cuda_nvcc
]);
buildInputs =
[
openssl.dev
]
++ (with python3.pkgs; [ ++ (with python3.pkgs; [
venvShellHook venvShellHook
docker docker
@ -40,10 +61,27 @@ mkShell {
pytest pytest
pytest-asyncio pytest-asyncio
syrupy syrupy
]); ])
++ (lib.optionals withCuda (
with cudaPackages;
[
cuda_cccl
cuda_cudart
cuda_nvtx
cudnn
libcublas
libcusolver
libcusparse
]
));
inputsFrom = [ server ]; inputsFrom = [ server ];
env = lib.optionalAttrs withCuda {
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
};
venvDir = "./.venv"; venvDir = "./.venv";
postVenvCreation = '' postVenvCreation = ''
@ -51,6 +89,7 @@ mkShell {
( cd server ; python -m pip install --no-dependencies -e . ) ( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . ) ( cd clients/python ; python -m pip install --no-dependencies -e . )
''; '';
postShellHook = '' postShellHook = ''
unset SOURCE_DATE_EPOCH unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin export PATH=$PATH:~/.cargo/bin