diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index aa7db30..fe7cb70 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -95,6 +95,8 @@ jobs: file: Dockerfile push: ${{ github.event_name != 'pull_request' }} platforms: 'linux/amd64' + build-args: | + GIT_SHA={{ env.GITHUB_SHA }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=max @@ -176,6 +178,8 @@ jobs: file: Dockerfile push: ${{ github.event_name != 'pull_request' }} platforms: 'linux/amd64' + build-args: | + GIT_SHA={{ env.GITHUB_SHA }} target: sagemaker tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} diff --git a/Dockerfile b/Dockerfile index 61f37ed..a679db7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,8 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder +ARG GIT_SHA + RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ diff --git a/router/build.rs b/router/build.rs index c34f9fa..1b1fdc8 100644 --- a/router/build.rs +++ b/router/build.rs @@ -2,6 +2,18 @@ use std::error::Error; use vergen::EmitBuilder; fn main() -> Result<(), Box> { - EmitBuilder::builder().git_sha(false).emit()?; + // Try to get the git sha from the local git repository + if EmitBuilder::builder() + .fail_on_error() + .git_sha(false) + .emit() + .is_err() + { + // Unable to get the git sha + if let Ok(sha) = std::env::var("GIT_SHA") { + // Set it from an env var + println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}"); + } + } Ok(()) } diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 0a29b3c..74a7483 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -17,13 +17,6 @@ from text_generation_server.models.gpt_neox import GPTNeoxSharded from text_generation_server.models.t5 import T5Sharded try: - from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded - from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded - from text_generation_server.models.flash_santacoder import ( - FlashSantacoder, - FlashSantacoderSharded, - ) - if torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 @@ -32,7 +25,20 @@ try: supported = is_sm75 or is_sm8x or is_sm90 if not supported: - raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported") + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) + + from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded + from text_generation_server.models.flash_llama import ( + FlashLlama, + FlashLlamaSharded, + ) + from text_generation_server.models.flash_santacoder import ( + FlashSantacoder, + FlashSantacoderSharded, + ) + FLASH_ATTENTION = True else: FLASH_ATTENTION = False