diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 2e0f8588..fddd734b 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -22,6 +22,7 @@ include(cmake/spdlog.cmake)
 # Let's find LibTorch
 include(cmake/torch.cmake)
 find_package(Torch REQUIRED)
+find_package(CUDAToolkit REQUIRED)
 find_package(Python3 COMPONENTS Interpreter)
 
 # TGI common
diff --git a/csrc/cmake/torch.cmake b/csrc/cmake/torch.cmake
index a8c2af15..b818d468 100644
--- a/csrc/cmake/torch.cmake
+++ b/csrc/cmake/torch.cmake
@@ -1,7 +1,7 @@
 fetchcontent_declare(
-        Torch
+        torch
         URL https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu124.zip
 #       OVERRIDE_FIND_PACKAGE
 )
-FetchContent_MakeAvailable(Torch)
-list(APPEND CMAKE_PREFIX_PATH ${Torch_SOURCE_DIR})
+FetchContent_MakeAvailable(torch)
+list(APPEND CMAKE_PREFIX_PATH ${torch_SOURCE_DIR})
diff --git a/csrc/common/CMakeLists.txt b/csrc/common/CMakeLists.txt
index 831c0821..bd44e4ed 100644
--- a/csrc/common/CMakeLists.txt
+++ b/csrc/common/CMakeLists.txt
@@ -2,8 +2,8 @@ set(TGI_COMMON_HEADERS include/common/device.hpp)
 set(TGI_COMMON_SOURCES lib/device.cpp)
 
-add_library(tgi_common SHARED ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
-target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES})
+add_library(tgi_common STATIC ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
+target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES} CUDA::nvml)
 target_include_directories(tgi_common PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/common>)
diff --git a/csrc/common/include/common/device.hpp b/csrc/common/include/common/device.hpp
index 183e78f8..c0f10dec 100644
--- a/csrc/common/include/common/device.hpp
+++ b/csrc/common/include/common/device.hpp
@@ -7,8 +7,12 @@
 #include <cstdint>
 #include <optional>
 #include <nvml.h>
+#include <mutex>
 
 namespace huggingface::tgi {
+    static std::once_flag NVML_INIT_FLAG;
+#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2)
+
     using device_index_t = uint8_t;
 
     /**
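Aside: the ENSURE_NVML_INIT() macro added to device.hpp wraps std::call_once so that nvmlInit_v2 runs at most once per process, no matter which helper touches NVML first. A minimal standalone sketch of the same pattern (an installed NVML/CUDA driver is assumed; the EnsureNvmlInit wrapper name and the error handling are illustrative, not part of this diff):

    #include <cstdio>
    #include <cstdlib>
    #include <mutex>
    #include <nvml.h>

    static std::once_flag NVML_INIT_FLAG;

    // std::call_once guarantees exactly-once initialization even when
    // several threads race into their first NVML query.
    static void EnsureNvmlInit()
    {
        std::call_once(NVML_INIT_FLAG, []() {
            if (nvmlInit_v2() != NVML_SUCCESS)
                std::abort(); // illustrative: real code should surface the failure
        });
    }

    int main()
    {
        EnsureNvmlInit();

        unsigned int count = 0;
        if (nvmlDeviceGetCount_v2(&count) == NVML_SUCCESS)
            std::printf("NVML sees %u device(s)\n", count);

        return 0;
    }
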
diff --git a/csrc/common/lib/device.cpp b/csrc/common/lib/device.cpp
index a703f40b..c80a0d5f 100644
--- a/csrc/common/lib/device.cpp
+++ b/csrc/common/lib/device.cpp
@@ -1,20 +1,52 @@
 //
 // Created by morgan on 27/09/24.
 //
+#include <nvml.h>
+#include <spdlog/spdlog.h>
 #include <optional>
 #include "device.hpp"
 
-std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t device)
+std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t index)
 {
+    ENSURE_NVML_INIT();
+
+    nvmlDevice_t device;
+    if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
+        return std::optional{ device };
+
     return std::nullopt;
 }
 
+bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
+{
+    ENSURE_NVML_INIT();
+
+    // Get devices
+    const auto devFrom = GetDeviceByIndex(from);
+    const auto devTo = GetDeviceByIndex(to);
+
+    if(!devFrom.has_value())
+        SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
+
+    if(!devTo.has_value())
+        SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
+
+    // Bail out instead of dereferencing an empty optional below
+    if(!devFrom.has_value() || !devTo.has_value())
+        return false;
+
+    // Query link between both
+    nvmlGpuP2PStatus_t status;
+    if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
+    {
+        SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
+        return false;
+    }
+
+    return status == NVML_P2P_STATUS_OK;
+}
+
 bool huggingface::tgi::IsP2PComplete()
 {
     return false;
 }
-
-bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
-{
-    return false;
-}
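With the NVML helpers relocated into tgi_common, any target linking it can probe NVLink peer-to-peer reachability for every GPU pair. A hedged usage sketch (the include path and the device-count query are assumptions; the diff itself only exposes GetDeviceByIndex and IsP2PAvailable):

    #include <cstdio>
    #include <nvml.h>
    #include "device.hpp" // resolved via tgi_common's include dir (path assumed)

    int main()
    {
        // Device count comes straight from NVML; IsP2PAvailable performs
        // its own lazy nvmlInit_v2 through ENSURE_NVML_INIT().
        unsigned int count = 0;
        nvmlInit_v2();
        nvmlDeviceGetCount_v2(&count);

        // Print the NVLink peer-to-peer matrix over all ordered GPU pairs.
        for (unsigned int from = 0; from < count; ++from)
            for (unsigned int to = 0; to < count; ++to)
                if (from != to)
                    std::printf("GPU %u <-> GPU %u : %s\n", from, to,
                                huggingface::tgi::IsP2PAvailable(from, to) ? "NVLink" : "no NVLink");

        return 0;
    }
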
diff --git a/csrc/tgiccl/CMakeLists.txt b/csrc/tgiccl/CMakeLists.txt
index 1437371a..e0bcc68b 100644
--- a/csrc/tgiccl/CMakeLists.txt
+++ b/csrc/tgiccl/CMakeLists.txt
@@ -3,9 +3,7 @@ project(tgiccl LANGUAGES C CXX CUDA)
 set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
 set(TGICCL_SOURCES TgiCclBackend.cpp)
 
-find_package(CUDAToolkit REQUIRED)
-
-add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES})
+add_library(tgiccl STATIC ${TGICCL_HEADERS} ${TGICCL_SOURCES})
 target_link_libraries(tgiccl PUBLIC tgi_common fmt::fmt spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})
 
 add_executable(test_tgiccl test_tgiccl.cpp)
diff --git a/csrc/tgiccl/TgiCclBackend.cpp b/csrc/tgiccl/TgiCclBackend.cpp
index 1c4250bc..19c5cdfc 100644
--- a/csrc/tgiccl/TgiCclBackend.cpp
+++ b/csrc/tgiccl/TgiCclBackend.cpp
@@ -5,7 +5,42 @@
 
 #include "TgiCclBackend.hpp"
 
+auto fmt::formatter<c10d::ReduceOp>::format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator {
+    string_view name = "unknown";
+    switch (op) {
+        case c10d::ReduceOp::AVG:        name = "ReduceOp::AVG"; break;
+        case c10d::ReduceOp::BAND:       name = "ReduceOp::BAND"; break;
+        case c10d::ReduceOp::BOR:        name = "ReduceOp::BOR"; break;
+        case c10d::ReduceOp::BXOR:       name = "ReduceOp::BXOR"; break;
+        case c10d::ReduceOp::MAX:        name = "ReduceOp::MAX"; break;
+        case c10d::ReduceOp::MIN:        name = "ReduceOp::MIN"; break;
+        case c10d::ReduceOp::PREMUL_SUM: name = "ReduceOp::PREMUL_SUM"; break;
+        case c10d::ReduceOp::PRODUCT:    name = "ReduceOp::PRODUCT"; break;
+        case c10d::ReduceOp::SUM:        name = "ReduceOp::SUM"; break;
+        case c10d::ReduceOp::UNUSED:     name = "ReduceOp::UNUSED"; break;
+    }
+    return formatter<string_view>::format(name, ctx);
+}
+
 void huggingface::tgi::tgiccl::InitTgiCcl() {
-}
\ No newline at end of file
+}
+
+huggingface::tgi::tgiccl::TgiCclBackend::TgiCclBackend(const int rank, const int size) : Backend(rank, size) {
+    SPDLOG_INFO(FMT_STRING("Creating {} on rank {:d} (world_size={:d})"), getBackendName(), rank, size);
+}
+
+const std::string huggingface::tgi::tgiccl::TgiCclBackend::getBackendName() const {
+    return CCL_BACKEND_NAME;
+}
+
+c10::intrusive_ptr<c10d::Work>
+huggingface::tgi::tgiccl::TgiCclBackend::allreduce(std::vector<at::Tensor> &tensors, const c10d::AllreduceOptions &options) {
+    TORCH_CHECK(options.reduceOp == c10d::ReduceOp::SUM, fmt::format(FMT_STRING("tgiccl only supports ReduceOp::SUM, got {}"), options.reduceOp));
+    tensors[0] += 1;
+    return c10::make_intrusive<TgiCclBackendWork>();
+}
diff --git a/csrc/tgiccl/TgiCclBackend.hpp b/csrc/tgiccl/TgiCclBackend.hpp
index cbf6e0e1..a0ed6994 100644
--- a/csrc/tgiccl/TgiCclBackend.hpp
+++ b/csrc/tgiccl/TgiCclBackend.hpp
@@ -9,17 +9,30 @@
 
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 
+template <> struct fmt::formatter<c10d::ReduceOp> : formatter<string_view> {
+    auto format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator;
+};
+
 namespace huggingface::tgi::tgiccl
 {
+#define CCL_BACKEND_NAME "tgiccl"
+
     void InitTgiCcl();
 
+    class TgiCclBackend;
+
+    class TgiCclBackendWork final : public c10d::Work {
+        friend TgiCclBackend;
+    };
+
-    class TgiCclBackend final : c10d::Backend {
+    class TgiCclBackend final : public c10d::Backend {
     public:
-        TgiCclBackend(const int rank, const int size): Backend(rank, size)
-        {
-            SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size);
-        }
-
+        TgiCclBackend(int rank, int size);
+        const std::string getBackendName() const override;
         c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
     };
 }
diff --git a/csrc/tgiccl/test_tgiccl.cpp b/csrc/tgiccl/test_tgiccl.cpp
index f7a25f6b..84bef94d 100644
--- a/csrc/tgiccl/test_tgiccl.cpp
+++ b/csrc/tgiccl/test_tgiccl.cpp
@@ -2,10 +2,14 @@
 // Created by morgan on 26/09/24.
 //
 
+#include <torch/torch.h>
+#include <vector>
 #include "tgiccl.hpp"
 
 int main() {
-    auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1);
-    auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2);
-    auto d = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3);
+    auto backend = huggingface::tgi::tgiccl::TgiCclBackend(0, 4);
+    auto tensor = torch::zeros({128});
+    auto tensors = std::vector<at::Tensor>();
+    tensors.push_back(tensor);
+    backend.allreduce(tensors, c10d::AllreduceOptions());
 }
\ No newline at end of file
diff --git a/csrc/tgiccl/tgiccl.hpp b/csrc/tgiccl/tgiccl.hpp
index 782ee116..ca509673 100644
--- a/csrc/tgiccl/tgiccl.hpp
+++ b/csrc/tgiccl/tgiccl.hpp
@@ -11,48 +11,8 @@
 
 #include "TgiCclBackend.hpp"
 
-constexpr auto CLL_BACKEND_NAME = "tgiccl";
-
 namespace huggingface::tgi::tgiccl
 {
-    static std::once_flag NVML_INIT_FLAG;
-#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
-
-    inline std::optional<nvmlDevice_t> GetDeviceByIndex(const size_t index)
-    {
-        ENSURE_NVML_INIT();
-
-        nvmlDevice_t device;
-        if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
-            return std::optional{ device };
-
-        return std::nullopt;
-    }
-
-    inline bool IsNvLinkAvailable(const int from, const int to)
-    {
-        ENSURE_NVML_INIT();
-
-        // Get devices
-        const auto devFrom = GetDeviceByIndex(from);
-        const auto devTo = GetDeviceByIndex(to);
-
-        if(!devFrom.has_value())
-            SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
-
-        if(!devTo.has_value())
-            SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
-
-        // Query link between both
-        nvmlGpuP2PStatus_t status;
-        if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
-        {
-            SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
-            return false;
-        }
-
-        return status == NVML_P2P_STATUS_OK;
-    }
 }
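Note on the Work object: TgiCclBackendWork is still an empty shell, and allreduce() returns a default-constructed instance whose completion semantics come from the c10d::Work base class. Since this allreduce stub completes synchronously, a later revision would presumably want the Work handle to report immediate completion. A sketch of what that could look like (class name hypothetical; the override set follows the c10d::Work interface in the libtorch 2.4 line pinned by torch.cmake):

    #include <chrono>
    #include <torch/csrc/distributed/c10d/Work.hpp>

    namespace huggingface::tgi::tgiccl
    {
        // Work handle for an operation that finished synchronously inside
        // allreduce(): always complete, always successful, wait() is a no-op.
        class SyncTgiCclWork final : public c10d::Work
        {
        public:
            bool isCompleted() override { return true; }
            bool isSuccess() const override { return true; }
            bool wait(std::chrono::milliseconds /*timeout*/) override { return true; }
        };
    }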