Add some utility functions in tgiccl for now

Morgan Funtowicz 2024-09-26 23:31:07 +02:00
parent 105f384461
commit 31a6065fac
8 changed files with 110 additions and 29 deletions

csrc/CMakeLists.txt

@@ -16,6 +16,7 @@ option(TGI_BUILD_CCL "Flag to enable/disable build of tgiccl collective library"
 # Add some modules
 include(FetchContent)
+include(cmake/spdlog.cmake)
 # Let's find LibTorch
 include(cmake/torch.cmake)
@@ -28,6 +29,5 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
 # Include submodules
 if (${TGI_BUILD_CCL})
-    include(cmake/nvshmem.cmake)
     add_subdirectory(tgiccl)
 endif ()

csrc/cmake/nvshmem.cmake

@@ -1,15 +0,0 @@
if (CMAKE_BUILD_TYPE STREQUAL "Release")
    set(NVSHMEM_DEBUG OFF)
    set(NVSHMEM_VERBOSE OFF)
else ()
    set(NVSHMEM_DEBUG ON)
    set(NVSHMEM_VERBOSE ON)
endif ()

fetchcontent_declare(
        nvshmem
        URL https://developer.download.nvidia.com/compute/redist/nvshmem/3.0.6/source/nvshmem_src_3.0.6-4.txz
        DOWNLOAD_EXTRACT_TIMESTAMP
)
fetchcontent_makeavailable(nvshmem)

csrc/cmake/spdlog.cmake Normal file

@@ -0,0 +1,6 @@
fetchcontent_declare(
        spdlog
        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
)
fetchcontent_makeavailable(spdlog)
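spdlog is pulled in here because the new tgiccl sources below log through its SPDLOG_* macros with compile-time-checked format strings. A minimal, self-contained sketch of that logging pattern (the rank/size values are illustrative only, not part of the commit):

#include <spdlog/spdlog.h>

int main() {
    // SPDLOG_* macros are compiled in down to SPDLOG_ACTIVE_LEVEL (info by default);
    // FMT_STRING checks the format string at compile time, as in TgiCclBackend.hpp below.
    spdlog::set_level(spdlog::level::info);
    SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), 0, 2);
    return 0;
}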

csrc/tgiccl/CMakeLists.txt

@@ -1,7 +1,12 @@
-project(tgiccl)
+project(tgiccl LANGUAGES C CXX CUDA)
 
-set(TGICCL_HEADER_FILES tgiccl.hpp)
-#set(TGICCL_SOURCE_FILES)
+set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
+set(TGICCL_SOURCES TgiCclBackend.cpp)
 
-add_library(tgiccl SHARED ${TGICCL_HEADER_FILES})
-target_link_libraries(tgiccl nvshmem)
+find_package(CUDAToolkit REQUIRED)
+
+add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES})
+target_link_libraries(tgiccl PUBLIC spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})
+
+add_executable(test_tgiccl test_tgiccl.cpp)
+target_link_libraries(test_tgiccl tgiccl spdlog::spdlog)

csrc/tgiccl/TgiCclBackend.cpp Normal file

@@ -0,0 +1,11 @@
//
// Created by Morgan Funtowicz on 26/09/24.
//
#include "TgiCclBackend.hpp"
void huggingface::tgi::tgiccl::InitTgiCcl()
{
}
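InitTgiCcl is left empty in this commit. Purely as an assumption about its eventual role, a minimal sketch could eagerly bring up NVML through the ENSURE_NVML_INIT helper declared in tgiccl.hpp (shown further down), so later topology queries never hit an uninitialized NVML:

// Hypothetical sketch only -- not what this commit ships.
#include <mutex>          // std::once_flag used by ENSURE_NVML_INIT in tgiccl.hpp
#include "tgiccl.hpp"

void huggingface::tgi::tgiccl::InitTgiCcl()
{
    // Assumption: warm up NVML once before any backend is constructed.
    ENSURE_NVML_INIT();
    SPDLOG_INFO("tgiccl initialized");
}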

csrc/tgiccl/TgiCclBackend.hpp Normal file

@@ -0,0 +1,29 @@
//
// Created by Morgan Funtowicz on 26/09/24.
//
#ifndef TGICCLPROCESSGROUP_H
#define TGICCLPROCESSGROUP_H
#include <spdlog/spdlog.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
namespace huggingface::tgi::tgiccl
{
    void InitTgiCcl();

    class TgiCclBackend final : c10d::Backend {
    public:
        TgiCclBackend(const int rank, const int size): Backend(rank, size)
        {
            SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size);
        }

        c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
    };
}
#endif //TGICCLPROCESSGROUP_H
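The allreduce override above is declared but never defined in this commit, so linking anything that instantiates TgiCclBackend will typically fail until a body lands. A placeholder definition for TgiCclBackend.cpp (an assumption, not the author's implementation) could simply fail loudly instead of pretending the collective ran:

// Hypothetical placeholder -- not part of this commit.
#include <c10/util/Exception.h>
#include "TgiCclBackend.hpp"

c10::intrusive_ptr<c10d::Work> huggingface::tgi::tgiccl::TgiCclBackend::allreduce(
    std::vector<at::Tensor>& /*tensors*/,
    const c10d::AllreduceOptions& /*opts*/)
{
    // Refuse to silently return a null Work handle for a collective that never ran.
    TORCH_CHECK(false, "tgiccl: allreduce is not implemented yet");
    return {}; // unreachable, keeps the compiler quiet about the missing return
}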

csrc/tgiccl/test_tgiccl.cpp Normal file

@@ -0,0 +1,11 @@
//
// Created by morgan on 26/09/24.
//
#include "tgiccl.hpp"

int main() {
    // Probe NVLink connectivity between GPU 0 and GPUs 1, 2 and 3; the results are not used yet.
    auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1);
    auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2);
    auto c = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3);
}

csrc/tgiccl/tgiccl.hpp

@@ -5,18 +5,52 @@
 #ifndef TEXT_GENERATION_INFERENCE_TGICCL_H
 #define TEXT_GENERATION_INFERENCE_TGICCL_H
 
-#include <torch/csrc/distributed/c10d/Backend.hpp>
-#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+#include <optional>
+#include <nvml.h>
+#include "TgiCclBackend.hpp"
 
-constexpr const char *CLL_BACKEND_NAME = "tgiccl";
+constexpr auto CLL_BACKEND_NAME = "tgiccl";
 
-namespace huggingface::tgi {
-    class TgiCcl {
-    private:
-    public:
-    };
+namespace huggingface::tgi::tgiccl {
+
+    static std::once_flag NVML_INIT_FLAG;
+#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
+
+    inline std::optional<nvmlDevice_t> GetDeviceByIndex(const size_t index)
+    {
+        ENSURE_NVML_INIT();
+
+        nvmlDevice_t device;
+        if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
+            return std::optional{ device };
+
+        return std::nullopt;
+    }
+
+    inline bool IsNvLinkAvailable(const int from, const int to)
+    {
+        ENSURE_NVML_INIT();
+
+        // Get devices
+        const auto devFrom = GetDeviceByIndex(from);
+        const auto devTo = GetDeviceByIndex(to);
+
+        if(!devFrom.has_value())
+            SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
+        if(!devTo.has_value())
+            SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
+
+        // Query link between both
+        nvmlGpuP2PStatus_t status;
+        if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
+            SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
+
+        return status == NVML_P2P_STATUS_OK;
+    }
 }
 #endif //TEXT_GENERATION_INFERENCE_TGICCL_H
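As a usage sketch of the helpers above (an assumption, not part of the commit), the pairwise check extends naturally into a scan over every visible device; nvmlDeviceGetCount_v2 is the standard NVML call for the device count, everything else reuses the header as-is:

// Hypothetical topology scan -- not part of this commit.
#include <nvml.h>
#include <spdlog/spdlog.h>
#include "tgiccl.hpp"

int main() {
    unsigned int count = 0;
    if (nvmlInit_v2() != NVML_SUCCESS || nvmlDeviceGetCount_v2(&count) != NVML_SUCCESS)
        return 1;

    // Report NVLink reachability from GPU 0 to every other visible device.
    for (unsigned int peer = 1; peer < count; ++peer) {
        const auto linked = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, static_cast<int>(peer));
        SPDLOG_INFO(FMT_STRING("GPU 0 <-> GPU {:d}: NVLink {}"), peer, linked ? "available" : "unavailable");
    }
    return 0;
}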