# hf_text-generation-inference/flake.nix (Nix, 94 lines, 2.3 KiB)

# Development flake: exposes a single `devShells.default` per default system
# containing a Rust toolchain, a Python virtualenv hook, the Python packages
# listed below, and the packages built from ./router.nix and ./_launcher.nix.
{
  inputs = {
    tgi-nix.url = "github:danieldk/tgi-nix";
    # nixpkgs is not pinned independently; it follows the revision that
    # tgi-nix itself uses so that its overlay is applied to a matching tree.
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
    rust-overlay = {
      url = "github:oxalica/rust-overlay";
      # Share the same nixpkgs as above instead of pulling a second copy.
      inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
    };
  };
  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      rust-overlay,
      tgi-nix,
    }:
    # Instantiate the outputs below once for each of flake-utils' default systems.
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        # nixpkgs configuration: permit unfree packages and enable CUDA support.
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
        # Package set for `system` with both overlays applied.
        pkgs = import nixpkgs {
          inherit config system;
          overlays = [
            # Provides the `rust-bin` attribute used for the toolchain below.
            rust-overlay.overlays.default
            # NOTE(review): presumably the source of the non-upstream Python
            # packages listed below (e.g. marlin-kernels) - confirm against tgi-nix.
            tgi-nix.overlay
          ];
        };
      in
      {
        devShells.default =
          with pkgs;
          mkShell {
            buildInputs =
              [
                openssl.dev
                pkg-config
                # Latest stable Rust toolchain plus editor support components.
                (rust-bin.stable.latest.default.override {
                  extensions = [
                    "rust-analyzer"
                    "rust-src"
                  ];
                })
              ]
              ++ (with python3.pkgs; [
                # Shell hook that creates/activates the virtualenv at `venvDir`.
                venvShellHook
                pip

                einops
                fbgemm-gpu
                flash-attn
                flash-attn-layer-norm
                flash-attn-rotary
                grpc-interceptor
                grpcio-reflection
                grpcio-status
                hf-transfer
                loguru
                marlin-kernels
                opentelemetry-api
                opentelemetry-exporter-otlp
                opentelemetry-instrumentation-grpc
                opentelemetry-semantic-conventions
                peft
                tokenizers
                torch
                transformers
                vllm

                # Local package expressions; both receive the Rust builders
                # taken from `rustPlatform`.
                (callPackage ./router.nix {
                  inherit (rustPlatform) buildRustPackage importCargoLock;
                })
                (callPackage ./_launcher.nix {
                  inherit (rustPlatform) buildRustPackage importCargoLock;
                })
              ]);

            # Location of the virtualenv managed by venvShellHook.
            venvDir = "./.venv";

            # Runs once, right after venvShellHook creates the virtualenv.
            # This was previously spelled `postVenv`, which venvShellHook does
            # not run, so the snippet never executed.
            # NOTE(review): unsetting SOURCE_DATE_EPOCH is the Nixpkgs-manual
            # recipe for letting pip install wheels - confirm that is the intent.
            postVenvCreation = ''
              unset SOURCE_DATE_EPOCH
            '';
            # Runs on every shell entry, after the virtualenv is activated.
            postShellHook = ''
              unset SOURCE_DATE_EPOCH
            '';
          };
      }
    );
}