Add some utility functions in tgiccl for now
This commit is contained in:
parent
105f384461
commit
31a6065fac
|
@ -16,6 +16,7 @@ option(TGI_BUILD_CCL "Flag to enable/disable build of tgiccl collective library"
|
|||
|
||||
# Add some modules
|
||||
include(FetchContent)
|
||||
include(cmake/spdlog.cmake)
|
||||
|
||||
# Let's find LibTorch
|
||||
include(cmake/torch.cmake)
|
||||
|
@ -28,6 +29,5 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
|||
|
||||
# Include submodules
|
||||
if (${TGI_BUILD_CCL})
|
||||
include(cmake/nvshmem.cmake)
|
||||
add_subdirectory(tgiccl)
|
||||
endif ()
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
if (CMAKE_BUILD_TYPE STREQUAL "Release")
|
||||
set(NVSHMEM_DEBUG OFF)
|
||||
set(NVSHMEM_VERBOSE OFF)
|
||||
else ()
|
||||
set(NVSHMEM_DEBUG ON)
|
||||
set(NVSHMEM_VERBOSE ON)
|
||||
endif ()
|
||||
|
||||
fetchcontent_declare(
|
||||
nvshmem
|
||||
URL https://developer.download.nvidia.com/compute/redist/nvshmem/3.0.6/source/nvshmem_src_3.0.6-4.txz
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||
)
|
||||
|
||||
fetchcontent_makeavailable(nvshmem)
|
|
@ -0,0 +1,6 @@
|
|||
fetchcontent_declare(
|
||||
spdlog
|
||||
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
|
||||
)
|
||||
|
||||
fetchcontent_makeavailable(spdlog)
|
|
@ -1,7 +1,12 @@
|
|||
project(tgiccl)
|
||||
project(tgiccl LANGUAGES C CXX CUDA)
|
||||
|
||||
set(TGICCL_HEADER_FILES tgiccl.hpp)
|
||||
#set(TGICCL_SOURCE_FILES)
|
||||
set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
|
||||
set(TGICCL_SOURCES TgiCclBackend.cpp)
|
||||
|
||||
add_library(tgiccl SHARED ${TGICCL_HEADER_FILES})
|
||||
target_link_libraries(tgiccl nvshmem)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES})
|
||||
target_link_libraries(tgiccl PUBLIC spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})
|
||||
|
||||
add_executable(test_tgiccl test_tgiccl.cpp)
|
||||
target_link_libraries(test_tgiccl tgiccl spdlog::spdlog)
|
|
@ -0,0 +1,11 @@
|
|||
//
|
||||
// Created by Morgan Funtowicz on 26/09/24.
|
||||
//
|
||||
|
||||
#include "TgiCclBackend.hpp"
|
||||
|
||||
|
||||
void huggingface::tgi::tgiccl::InitTgiCcl()
|
||||
{
|
||||
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
//
|
||||
// Created by Morgan Funtowicz on 26/09/24.
|
||||
//
|
||||
|
||||
#ifndef TGICCLPROCESSGROUP_H
|
||||
#define TGICCLPROCESSGROUP_H
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
||||
|
||||
|
||||
namespace huggingface::tgi::tgiccl
|
||||
{
|
||||
void InitTgiCcl();
|
||||
|
||||
class TgiCclBackend final : c10d::Backend {
|
||||
public:
|
||||
TgiCclBackend(const int rank, const int size): Backend(rank, size)
|
||||
{
|
||||
SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif //TGICCLPROCESSGROUP_H
|
|
@ -0,0 +1,11 @@
|
|||
//
|
||||
// Created by morgan on 26/09/24.
|
||||
//
|
||||
|
||||
#include "tgiccl.hpp"
|
||||
|
||||
int main() {
|
||||
auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1);
|
||||
auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2);
|
||||
auto c = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3);
|
||||
}
|
|
@ -5,18 +5,52 @@
|
|||
#ifndef TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||
#define TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||
|
||||
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
#include <optional>
|
||||
|
||||
constexpr const char *CLL_BACKEND_NAME = "tgiccl";
|
||||
#include <nvml.h>
|
||||
|
||||
namespace huggingface::tgi {
|
||||
class TgiCcl {
|
||||
private:
|
||||
#include "TgiCclBackend.hpp"
|
||||
|
||||
public:
|
||||
constexpr auto CLL_BACKEND_NAME = "tgiccl";
|
||||
|
||||
namespace huggingface::tgi::tgiccl {
|
||||
|
||||
static std::once_flag NVML_INIT_FLAG;
|
||||
#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
|
||||
|
||||
inline std::optional<nvmlDevice_t> GetDeviceByIndex(const size_t index)
|
||||
{
|
||||
ENSURE_NVML_INIT();
|
||||
|
||||
nvmlDevice_t device;
|
||||
if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
|
||||
return std::optional{ device };
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
inline bool IsNvLinkAvailable(const int from, const int to)
|
||||
{
|
||||
ENSURE_NVML_INIT();
|
||||
|
||||
// Get devices
|
||||
const auto devFrom = GetDeviceByIndex(from);
|
||||
const auto devTo = GetDeviceByIndex(to);
|
||||
|
||||
if(!devFrom.has_value())
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
|
||||
|
||||
if(!devTo.has_value())
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
|
||||
|
||||
// Query link between both
|
||||
nvmlGpuP2PStatus_t status;
|
||||
if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
|
||||
|
||||
return status == NVML_P2P_STATUS_OK;
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
#endif //TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||
|
|
Loading…
Reference in New Issue