feat(tgi_common) continue more utility functions
This commit is contained in:
parent
fb81ffce02
commit
513ba5a0b4
|
@ -22,6 +22,7 @@ include(cmake/spdlog.cmake)
|
||||||
# Let's find LibTorch
|
# Let's find LibTorch
|
||||||
include(cmake/torch.cmake)
|
include(cmake/torch.cmake)
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
find_package(CUDAToolkit REQUIRED)
|
||||||
find_package(Python3 COMPONENTS Interpreter)
|
find_package(Python3 COMPONENTS Interpreter)
|
||||||
|
|
||||||
# TGI common
|
# TGI common
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
Torch
|
torch
|
||||||
URL https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu124.zip
|
URL https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu124.zip
|
||||||
# OVERRIDE_FIND_PACKAGE
|
# OVERRIDE_FIND_PACKAGE
|
||||||
)
|
)
|
||||||
FetchContent_MakeAvailable(Torch)
|
FetchContent_MakeAvailable(torch)
|
||||||
list(APPEND CMAKE_PREFIX_PATH ${Torch_SOURCE_DIR})
|
list(APPEND CMAKE_PREFIX_PATH ${torch_SOURCE_DIR})
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
set(TGI_COMMON_HEADERS include/common/device.hpp)
|
set(TGI_COMMON_HEADERS include/common/device.hpp)
|
||||||
set(TGI_COMMON_SOURCES lib/device.cpp)
|
set(TGI_COMMON_SOURCES lib/device.cpp)
|
||||||
|
|
||||||
add_library(tgi_common SHARED ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
|
add_library(tgi_common STATIC ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
|
||||||
target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES})
|
target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES} CUDA::nvml)
|
||||||
|
|
||||||
target_include_directories(tgi_common PRIVATE
|
target_include_directories(tgi_common PRIVATE
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/common>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/common>
|
||||||
|
|
|
@ -7,8 +7,12 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <nvml.h>
|
#include <nvml.h>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
namespace huggingface::tgi {
|
namespace huggingface::tgi {
|
||||||
|
static std::once_flag NVML_INIT_FLAG;
|
||||||
|
#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
|
||||||
|
|
||||||
using device_index_t = uint8_t;
|
using device_index_t = uint8_t;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1,20 +1,48 @@
|
||||||
//
|
//
|
||||||
// Created by morgan on 27/09/24.
|
// Created by morgan on 27/09/24.
|
||||||
//
|
//
|
||||||
|
#include <fmt/format.h>
|
||||||
|
#include <spdlog/spdlog.h>
|
||||||
#include <nvml.h>
|
#include <nvml.h>
|
||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
|
|
||||||
std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t device)
|
std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t index)
|
||||||
{
|
{
|
||||||
|
ENSURE_NVML_INIT();
|
||||||
|
|
||||||
|
nvmlDevice_t device;
|
||||||
|
if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
|
||||||
|
return std::optional{ device };
|
||||||
|
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
|
||||||
|
{
|
||||||
|
ENSURE_NVML_INIT();
|
||||||
|
|
||||||
|
// Get devices
|
||||||
|
const auto devFrom = GetDeviceByIndex(from);
|
||||||
|
const auto devTo = GetDeviceByIndex(to);
|
||||||
|
|
||||||
|
if(!devFrom.has_value())
|
||||||
|
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
|
||||||
|
|
||||||
|
if(!devTo.has_value())
|
||||||
|
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
|
||||||
|
|
||||||
|
// Query link between both
|
||||||
|
nvmlGpuP2PStatus_t status;
|
||||||
|
if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
|
||||||
|
{
|
||||||
|
SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return status == NVML_P2P_STATUS_OK;
|
||||||
|
}
|
||||||
|
|
||||||
bool huggingface::tgi::IsP2PComplete()
|
bool huggingface::tgi::IsP2PComplete()
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
|
|
||||||
{
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,9 +3,7 @@ project(tgiccl LANGUAGES C CXX CUDA)
|
||||||
set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
|
set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
|
||||||
set(TGICCL_SOURCES TgiCclBackend.cpp)
|
set(TGICCL_SOURCES TgiCclBackend.cpp)
|
||||||
|
|
||||||
find_package(CUDAToolkit REQUIRED)
|
add_library(tgiccl STATIC ${TGICCL_HEADERS} ${TGICCL_SOURCES})
|
||||||
|
|
||||||
add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES})
|
|
||||||
target_link_libraries(tgiccl PUBLIC tgi_common fmt::fmt spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})
|
target_link_libraries(tgiccl PUBLIC tgi_common fmt::fmt spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})
|
||||||
|
|
||||||
add_executable(test_tgiccl test_tgiccl.cpp)
|
add_executable(test_tgiccl test_tgiccl.cpp)
|
||||||
|
|
|
@ -5,7 +5,42 @@
|
||||||
#include "TgiCclBackend.hpp"
|
#include "TgiCclBackend.hpp"
|
||||||
|
|
||||||
|
|
||||||
|
auto fmt::formatter<c10d::ReduceOp>::format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator {
|
||||||
|
string_view name = "unknown";
|
||||||
|
switch (op) {
|
||||||
|
case c10d::ReduceOp::AVG: name = "ReduceOp::AVG"; break;
|
||||||
|
case c10d::ReduceOp::BAND: name = "ReduceOp::BAND"; break;
|
||||||
|
case c10d::ReduceOp::BOR: name = "ReduceOp::BOR"; break;
|
||||||
|
case c10d::ReduceOp::BXOR: name = "ReduceOp::BXOR"; break;
|
||||||
|
case c10d::ReduceOp::MAX: name = "ReduceOp::MAX"; break;
|
||||||
|
case c10d::ReduceOp::MIN: name = "ReduceOp::MIN"; break;
|
||||||
|
case c10d::ReduceOp::PREMUL_SUM: name = "ReduceOp::PREMUL_SUM"; break;
|
||||||
|
case c10d::ReduceOp::PRODUCT: name = "ReduceOp::PRODUCT"; break;
|
||||||
|
case c10d::ReduceOp::SUM: name = "ReduceOp::SUM"; break;
|
||||||
|
case c10d::ReduceOp::UNUSED: name = "ReduceOp::UNUSED"; break;
|
||||||
|
}
|
||||||
|
return formatter<string_view>::format(name, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void huggingface::tgi::tgiccl::InitTgiCcl()
|
void huggingface::tgi::tgiccl::InitTgiCcl()
|
||||||
{
|
{
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
huggingface::tgi::tgiccl::TgiCclBackend::TgiCclBackend(const int rank, const int size) : Backend(rank, size) {
|
||||||
|
SPDLOG_INFO(FMT_STRING("Creating {} on rank {:d} (world_size={:d})"), getBackendName(), rank, size);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string huggingface::tgi::tgiccl::TgiCclBackend::getBackendName() const {
|
||||||
|
return CCL_BACKEND_NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::intrusive_ptr<c10d::Work>
|
||||||
|
huggingface::tgi::tgiccl::TgiCclBackend::allreduce(std::vector<at::Tensor> &tensors, const c10d::AllreduceOptions &options) {
|
||||||
|
TORCH_CHECK(options.reduceOp == c10d::ReduceOp::SUM, fmt::format(FMT_STRING("tgiccl only supports ReduceOp::SUM, got {}"), options.reduceOp))
|
||||||
|
tensors[0] += 1;
|
||||||
|
return c10::make_intrusive<c10d::Work>();
|
||||||
|
}
|
||||||
|
|
|
@ -9,17 +9,30 @@
|
||||||
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
||||||
|
|
||||||
|
|
||||||
|
template <> struct fmt::formatter<c10d::ReduceOp>: formatter<string_view> {
|
||||||
|
auto format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace huggingface::tgi::tgiccl
|
namespace huggingface::tgi::tgiccl
|
||||||
{
|
{
|
||||||
|
#define CCL_BACKEND_NAME "tgiccl";
|
||||||
|
|
||||||
void InitTgiCcl();
|
void InitTgiCcl();
|
||||||
|
|
||||||
|
class TgiCclBackend;
|
||||||
|
class TgiCclBackendWork final: c10d::Work {
|
||||||
|
friend TgiCclBackend;
|
||||||
|
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class TgiCclBackend final : c10d::Backend {
|
class TgiCclBackend final : c10d::Backend {
|
||||||
public:
|
public:
|
||||||
TgiCclBackend(const int rank, const int size): Backend(rank, size)
|
TgiCclBackend(int rank, int size);
|
||||||
{
|
const std::string getBackendName() const override;
|
||||||
SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
|
c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,14 @@
|
||||||
// Created by morgan on 26/09/24.
|
// Created by morgan on 26/09/24.
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <torch/torch.h>
|
||||||
#include "tgiccl.hpp"
|
#include "tgiccl.hpp"
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1);
|
auto backend = huggingface::tgi::tgiccl::TgiCclBackend(0, 4);
|
||||||
auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2);
|
auto tensor = torch::zeros({128});
|
||||||
auto d = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3);
|
auto tensors = std::vector<torch::Tensor>();
|
||||||
|
tensors.push_back(tensor);
|
||||||
|
backend.allreduce(tensors, c10d::AllreduceOptions());
|
||||||
}
|
}
|
|
@ -11,48 +11,8 @@
|
||||||
|
|
||||||
#include "TgiCclBackend.hpp"
|
#include "TgiCclBackend.hpp"
|
||||||
|
|
||||||
constexpr auto CLL_BACKEND_NAME = "tgiccl";
|
|
||||||
|
|
||||||
namespace huggingface::tgi::tgiccl
|
namespace huggingface::tgi::tgiccl
|
||||||
{
|
{
|
||||||
static std::once_flag NVML_INIT_FLAG;
|
|
||||||
#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
|
|
||||||
|
|
||||||
inline std::optional<nvmlDevice_t> GetDeviceByIndex(const size_t index)
|
|
||||||
{
|
|
||||||
ENSURE_NVML_INIT();
|
|
||||||
|
|
||||||
nvmlDevice_t device;
|
|
||||||
if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
|
|
||||||
return std::optional{ device };
|
|
||||||
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool IsNvLinkAvailable(const int from, const int to)
|
|
||||||
{
|
|
||||||
ENSURE_NVML_INIT();
|
|
||||||
|
|
||||||
// Get devices
|
|
||||||
const auto devFrom = GetDeviceByIndex(from);
|
|
||||||
const auto devTo = GetDeviceByIndex(to);
|
|
||||||
|
|
||||||
if(!devFrom.has_value())
|
|
||||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
|
|
||||||
|
|
||||||
if(!devTo.has_value())
|
|
||||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
|
|
||||||
|
|
||||||
// Query link between both
|
|
||||||
nvmlGpuP2PStatus_t status;
|
|
||||||
if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
|
|
||||||
{
|
|
||||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return status == NVML_P2P_STATUS_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue