feat(tgi_common) continue more utility functions
This commit is contained in:
parent
fb81ffce02
commit
513ba5a0b4
|
@ -22,6 +22,7 @@ include(cmake/spdlog.cmake)
|
|||
# Let's find LibTorch
|
||||
include(cmake/torch.cmake)
|
||||
find_package(Torch REQUIRED)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
find_package(Python3 COMPONENTS Interpreter)
|
||||
|
||||
# TGI common
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
fetchcontent_declare(
|
||||
Torch
|
||||
torch
|
||||
URL https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu124.zip
|
||||
# OVERRIDE_FIND_PACKAGE
|
||||
)
|
||||
FetchContent_MakeAvailable(Torch)
|
||||
list(APPEND CMAKE_PREFIX_PATH ${Torch_SOURCE_DIR})
|
||||
FetchContent_MakeAvailable(torch)
|
||||
list(APPEND CMAKE_PREFIX_PATH ${torch_SOURCE_DIR})
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
set(TGI_COMMON_HEADERS include/common/device.hpp)
|
||||
set(TGI_COMMON_SOURCES lib/device.cpp)
|
||||
|
||||
add_library(tgi_common SHARED ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
|
||||
target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES})
|
||||
add_library(tgi_common STATIC ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
|
||||
target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES} CUDA::nvml)
|
||||
|
||||
target_include_directories(tgi_common PRIVATE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/common>
|
||||
|
|
|
@ -7,8 +7,12 @@
|
|||
#include <cstdint>
|
||||
#include <nvml.h>
|
||||
#include <optional>
|
||||
#include <mutex>
|
||||
|
||||
namespace huggingface::tgi {
|
||||
static std::once_flag NVML_INIT_FLAG;
|
||||
#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
|
||||
|
||||
using device_index_t = uint8_t;
|
||||
|
||||
/**
|
||||
|
|
|
@ -1,20 +1,48 @@
|
|||
//
|
||||
// Created by morgan on 27/09/24.
|
||||
//
|
||||
#include <fmt/format.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <nvml.h>
|
||||
#include "device.hpp"
|
||||
|
||||
std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t device)
|
||||
std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t index)
|
||||
{
|
||||
ENSURE_NVML_INIT();
|
||||
|
||||
nvmlDevice_t device;
|
||||
if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
|
||||
return std::optional{ device };
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
|
||||
{
|
||||
ENSURE_NVML_INIT();
|
||||
|
||||
// Get devices
|
||||
const auto devFrom = GetDeviceByIndex(from);
|
||||
const auto devTo = GetDeviceByIndex(to);
|
||||
|
||||
if(!devFrom.has_value())
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
|
||||
|
||||
if(!devTo.has_value())
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
|
||||
|
||||
// Query link between both
|
||||
nvmlGpuP2PStatus_t status;
|
||||
if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
|
||||
{
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
|
||||
return false;
|
||||
}
|
||||
|
||||
return status == NVML_P2P_STATUS_OK;
|
||||
}
|
||||
|
||||
bool huggingface::tgi::IsP2PComplete()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -3,9 +3,7 @@ project(tgiccl LANGUAGES C CXX CUDA)
|
|||
set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
|
||||
set(TGICCL_SOURCES TgiCclBackend.cpp)
|
||||
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES})
|
||||
add_library(tgiccl STATIC ${TGICCL_HEADERS} ${TGICCL_SOURCES})
|
||||
target_link_libraries(tgiccl PUBLIC tgi_common fmt::fmt spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})
|
||||
|
||||
add_executable(test_tgiccl test_tgiccl.cpp)
|
||||
|
|
|
@ -5,7 +5,42 @@
|
|||
#include "TgiCclBackend.hpp"
|
||||
|
||||
|
||||
auto fmt::formatter<c10d::ReduceOp>::format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator {
|
||||
string_view name = "unknown";
|
||||
switch (op) {
|
||||
case c10d::ReduceOp::AVG: name = "ReduceOp::AVG"; break;
|
||||
case c10d::ReduceOp::BAND: name = "ReduceOp::BAND"; break;
|
||||
case c10d::ReduceOp::BOR: name = "ReduceOp::BOR"; break;
|
||||
case c10d::ReduceOp::BXOR: name = "ReduceOp::BXOR"; break;
|
||||
case c10d::ReduceOp::MAX: name = "ReduceOp::MAX"; break;
|
||||
case c10d::ReduceOp::MIN: name = "ReduceOp::MIN"; break;
|
||||
case c10d::ReduceOp::PREMUL_SUM: name = "ReduceOp::PREMUL_SUM"; break;
|
||||
case c10d::ReduceOp::PRODUCT: name = "ReduceOp::PRODUCT"; break;
|
||||
case c10d::ReduceOp::SUM: name = "ReduceOp::SUM"; break;
|
||||
case c10d::ReduceOp::UNUSED: name = "ReduceOp::UNUSED"; break;
|
||||
}
|
||||
return formatter<string_view>::format(name, ctx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void huggingface::tgi::tgiccl::InitTgiCcl()
|
||||
{
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
huggingface::tgi::tgiccl::TgiCclBackend::TgiCclBackend(const int rank, const int size) : Backend(rank, size) {
|
||||
SPDLOG_INFO(FMT_STRING("Creating {} on rank {:d} (world_size={:d})"), getBackendName(), rank, size);
|
||||
|
||||
}
|
||||
|
||||
const std::string huggingface::tgi::tgiccl::TgiCclBackend::getBackendName() const {
|
||||
return CCL_BACKEND_NAME;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<c10d::Work>
|
||||
huggingface::tgi::tgiccl::TgiCclBackend::allreduce(std::vector<at::Tensor> &tensors, const c10d::AllreduceOptions &options) {
|
||||
TORCH_CHECK(options.reduceOp == c10d::ReduceOp::SUM, fmt::format(FMT_STRING("tgiccl only supports ReduceOp::SUM, got {}"), options.reduceOp))
|
||||
tensors[0] += 1;
|
||||
return c10::make_intrusive<c10d::Work>();
|
||||
}
|
||||
|
|
|
@ -9,17 +9,30 @@
|
|||
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
||||
|
||||
|
||||
template <> struct fmt::formatter<c10d::ReduceOp>: formatter<string_view> {
|
||||
auto format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator;
|
||||
};
|
||||
|
||||
|
||||
|
||||
namespace huggingface::tgi::tgiccl
|
||||
{
|
||||
#define CCL_BACKEND_NAME "tgiccl";
|
||||
|
||||
void InitTgiCcl();
|
||||
|
||||
class TgiCclBackend;
|
||||
class TgiCclBackendWork final: c10d::Work {
|
||||
friend TgiCclBackend;
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
||||
class TgiCclBackend final : c10d::Backend {
|
||||
public:
|
||||
TgiCclBackend(const int rank, const int size): Backend(rank, size)
|
||||
{
|
||||
SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size);
|
||||
}
|
||||
|
||||
TgiCclBackend(int rank, int size);
|
||||
const std::string getBackendName() const override;
|
||||
c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -2,10 +2,14 @@
|
|||
// Created by morgan on 26/09/24.
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <torch/torch.h>
|
||||
#include "tgiccl.hpp"
|
||||
|
||||
int main() {
|
||||
auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1);
|
||||
auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2);
|
||||
auto d = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3);
|
||||
auto backend = huggingface::tgi::tgiccl::TgiCclBackend(0, 4);
|
||||
auto tensor = torch::zeros({128});
|
||||
auto tensors = std::vector<torch::Tensor>();
|
||||
tensors.push_back(tensor);
|
||||
backend.allreduce(tensors, c10d::AllreduceOptions());
|
||||
}
|
|
@ -11,48 +11,8 @@
|
|||
|
||||
#include "TgiCclBackend.hpp"
|
||||
|
||||
constexpr auto CLL_BACKEND_NAME = "tgiccl";
|
||||
|
||||
namespace huggingface::tgi::tgiccl
|
||||
{
|
||||
static std::once_flag NVML_INIT_FLAG;
|
||||
#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
|
||||
|
||||
inline std::optional<nvmlDevice_t> GetDeviceByIndex(const size_t index)
|
||||
{
|
||||
ENSURE_NVML_INIT();
|
||||
|
||||
nvmlDevice_t device;
|
||||
if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
|
||||
return std::optional{ device };
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
inline bool IsNvLinkAvailable(const int from, const int to)
|
||||
{
|
||||
ENSURE_NVML_INIT();
|
||||
|
||||
// Get devices
|
||||
const auto devFrom = GetDeviceByIndex(from);
|
||||
const auto devTo = GetDeviceByIndex(to);
|
||||
|
||||
if(!devFrom.has_value())
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
|
||||
|
||||
if(!devTo.has_value())
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
|
||||
|
||||
// Query link between both
|
||||
nvmlGpuP2PStatus_t status;
|
||||
if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
|
||||
{
|
||||
SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
|
||||
return false;
|
||||
}
|
||||
|
||||
return status == NVML_P2P_STATUS_OK;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue