Add some utility functions in tgiccl for now
This commit is contained in:
parent
105f384461
commit
31a6065fac
|
@ -16,6 +16,7 @@ option(TGI_BUILD_CCL "Flag to enable/disable build of tgiccl collective library"
|
||||||
|
|
||||||
# Add some modules
include(FetchContent)
include(cmake/spdlog.cmake)

# Let's find LibTorch
include(cmake/torch.cmake)

# Include submodules
if (${TGI_BUILD_CCL})
    add_subdirectory(tgiccl)
endif ()
|
|
@ -1,15 +0,0 @@
|
||||||
# nvshmem build configuration: quiet in Release builds, verbose otherwise.
if (CMAKE_BUILD_TYPE STREQUAL "Release")
    set(NVSHMEM_DEBUG OFF)
    set(NVSHMEM_VERBOSE OFF)
else ()
    set(NVSHMEM_DEBUG ON)
    set(NVSHMEM_VERBOSE ON)
endif ()

fetchcontent_declare(
        nvshmem
        URL https://developer.download.nvidia.com/compute/redist/nvshmem/3.0.6/source/nvshmem_src_3.0.6-4.txz
        # DOWNLOAD_EXTRACT_TIMESTAMP requires an explicit TRUE/FALSE value
        # (CMake policy CMP0135); a bare flag is rejected by FetchContent.
        DOWNLOAD_EXTRACT_TIMESTAMP TRUE
)

fetchcontent_makeavailable(nvshmem)
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
fetchcontent_declare(
        spdlog
        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
        # Keep extracted file timestamps deterministic and silence the
        # CMP0135 warning; matches the nvshmem fetch declaration.
        DOWNLOAD_EXTRACT_TIMESTAMP TRUE
)

fetchcontent_makeavailable(spdlog)
|
|
@ -1,7 +1,12 @@
|
||||||
project(tgiccl LANGUAGES C CXX CUDA)

set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
set(TGICCL_SOURCES TgiCclBackend.cpp)

# NVML (CUDA::nvml) is used by tgiccl.hpp to probe NVLink topology.
find_package(CUDAToolkit REQUIRED)

add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES})
target_link_libraries(tgiccl PUBLIC spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})

# Small smoke-test binary exercising the NVLink detection helpers.
# spdlog already propagates through tgiccl's PUBLIC link interface.
add_executable(test_tgiccl test_tgiccl.cpp)
target_link_libraries(test_tgiccl tgiccl spdlog::spdlog)
|
|
@ -0,0 +1,11 @@
|
||||||
|
//
// Created by Morgan Funtowicz on 26/09/24.
//

#include "TgiCclBackend.hpp"

// Collective-library bootstrap hook; intentionally a no-op stub for now.
void huggingface::tgi::tgiccl::InitTgiCcl()
{
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
//
|
||||||
|
// Created by Morgan Funtowicz on 26/09/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef TGICCLPROCESSGROUP_H
|
||||||
|
#define TGICCLPROCESSGROUP_H
|
||||||
|
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
|
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
||||||
|
|
||||||
|
|
||||||
|
namespace huggingface::tgi::tgiccl
|
||||||
|
{
|
||||||
|
void InitTgiCcl();
|
||||||
|
|
||||||
|
class TgiCclBackend final : c10d::Backend {
|
||||||
|
public:
|
||||||
|
TgiCclBackend(const int rank, const int size): Backend(rank, size)
|
||||||
|
{
|
||||||
|
SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#endif //TGICCLPROCESSGROUP_H
|
|
@ -0,0 +1,11 @@
|
||||||
|
//
|
||||||
|
// Created by morgan on 26/09/24.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "tgiccl.hpp"
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1);
|
||||||
|
auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2);
|
||||||
|
auto c = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3);
|
||||||
|
}
|
|
@ -5,18 +5,52 @@
|
||||||
#ifndef TEXT_GENERATION_INFERENCE_TGICCL_H
|
#ifndef TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||||
#define TEXT_GENERATION_INFERENCE_TGICCL_H
|
#define TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||||
|
|
||||||
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
#include <optional>
|
||||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
|
||||||
|
|
||||||
constexpr const char *CLL_BACKEND_NAME = "tgiccl";
|
#include <nvml.h>
|
||||||
|
|
||||||
namespace huggingface::tgi {
|
#include "TgiCclBackend.hpp"
|
||||||
class TgiCcl {
|
|
||||||
private:
|
|
||||||
|
|
||||||
public:
|
constexpr auto CLL_BACKEND_NAME = "tgiccl";
|
||||||
|
|
||||||
|
namespace huggingface::tgi::tgiccl {
|
||||||
|
|
||||||
|
static std::once_flag NVML_INIT_FLAG;
|
||||||
|
#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
|
||||||
|
|
||||||
|
inline std::optional<nvmlDevice_t> GetDeviceByIndex(const size_t index)
|
||||||
|
{
|
||||||
|
ENSURE_NVML_INIT();
|
||||||
|
|
||||||
|
nvmlDevice_t device;
|
||||||
|
if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
|
||||||
|
return std::optional{ device };
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool IsNvLinkAvailable(const int from, const int to)
|
||||||
|
{
|
||||||
|
ENSURE_NVML_INIT();
|
||||||
|
|
||||||
|
// Get devices
|
||||||
|
const auto devFrom = GetDeviceByIndex(from);
|
||||||
|
const auto devTo = GetDeviceByIndex(to);
|
||||||
|
|
||||||
|
if(!devFrom.has_value())
|
||||||
|
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
|
||||||
|
|
||||||
|
if(!devTo.has_value())
|
||||||
|
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
|
||||||
|
|
||||||
|
// Query link between both
|
||||||
|
nvmlGpuP2PStatus_t status;
|
||||||
|
if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
|
||||||
|
SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
|
||||||
|
|
||||||
|
return status == NVML_P2P_STATUS_OK;
|
||||||
|
}
|
||||||
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif //TEXT_GENERATION_INFERENCE_TGICCL_H
|
#endif //TEXT_GENERATION_INFERENCE_TGICCL_H
|
||||||
|
|
Loading…
Reference in New Issue