feat(tgi_common): continue adding utility functions

Morgan Funtowicz 2024-09-29 12:33:31 +00:00
parent fb81ffce02
commit 513ba5a0b4
10 changed files with 106 additions and 63 deletions


@@ -22,6 +22,7 @@ include(cmake/spdlog.cmake)
 # Let's find LibTorch
 include(cmake/torch.cmake)
 find_package(Torch REQUIRED)
+find_package(CUDAToolkit REQUIRED)
 find_package(Python3 COMPONENTS Interpreter)

 # TGI common


@@ -1,7 +1,7 @@
 fetchcontent_declare(
-    Torch
+    torch
     URL https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.4.1%2Bcu124.zip
     # OVERRIDE_FIND_PACKAGE
 )
-FetchContent_MakeAvailable(Torch)
-list(APPEND CMAKE_PREFIX_PATH ${Torch_SOURCE_DIR})
+FetchContent_MakeAvailable(torch)
+list(APPEND CMAKE_PREFIX_PATH ${torch_SOURCE_DIR})


@@ -2,8 +2,8 @@
 set(TGI_COMMON_HEADERS include/common/device.hpp)
 set(TGI_COMMON_SOURCES lib/device.cpp)

-add_library(tgi_common SHARED ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
-target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES})
+add_library(tgi_common STATIC ${TGI_COMMON_HEADERS} ${TGI_COMMON_SOURCES})
+target_link_libraries(tgi_common fmt::fmt spdlog::spdlog ${TORCH_LIBRARIES} CUDA::nvml)
 target_include_directories(tgi_common PRIVATE
     $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/common>


@@ -7,8 +7,12 @@
 #include <cstdint>
 #include <nvml.h>
 #include <optional>
+#include <mutex>

 namespace huggingface::tgi {
+    static std::once_flag NVML_INIT_FLAG;
+#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
+
     using device_index_t = uint8_t;

     /**
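The initialization guard added to this header relies on std::call_once: any entry point that touches NVML runs ENSURE_NVML_INIT() first, and nvmlInit_v2 executes at most once even if several threads race on the first call. A minimal, self-contained sketch of the same pattern (the names here are illustrative, not part of the commit):

    #include <mutex>
    #include <nvml.h>

    static std::once_flag INIT_FLAG;

    // Safe to call from any thread, any number of times: nvmlInit_v2 runs only once.
    void EnsureNvmlInit()
    {
        std::call_once(INIT_FLAG, nvmlInit_v2);
    }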


@@ -1,20 +1,48 @@
 //
 // Created by morgan on 27/09/24.
 //
+#include <fmt/format.h>
+#include <spdlog/spdlog.h>
 #include <nvml.h>
 #include "device.hpp"

-std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t device)
+std::optional<nvmlDevice_t> huggingface::tgi::GetDeviceByIndex(device_index_t index)
 {
+    ENSURE_NVML_INIT();
+
+    nvmlDevice_t device;
+    if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
+        return std::optional{ device };
+
     return std::nullopt;
 }

+bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
+{
+    ENSURE_NVML_INIT();
+
+    // Get devices
+    const auto devFrom = GetDeviceByIndex(from);
+    const auto devTo = GetDeviceByIndex(to);
+
+    if(!devFrom.has_value())
+        SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
+
+    if(!devTo.has_value())
+        SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
+
+    // Query link between both
+    nvmlGpuP2PStatus_t status;
+    if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
+    {
+        SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
+        return false;
+    }
+
+    return status == NVML_P2P_STATUS_OK;
+}
+
 bool huggingface::tgi::IsP2PComplete()
 {
     return false;
 }
-
-bool huggingface::tgi::IsP2PAvailable(device_index_t from, device_index_t to)
-{
-    return false;
-}
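With GetDeviceByIndex and IsP2PAvailable now backed by NVML, a caller can probe NVLink peer-to-peer reachability per GPU pair. A minimal sketch of such a caller, assuming tgi_common's device.hpp is on the include path and an NVML-capable driver is present (the device count below is a placeholder, not something the commit defines):

    #include <cstdio>
    #include "device.hpp"

    int main()
    {
        using huggingface::tgi::IsP2PAvailable;

        constexpr uint8_t kDeviceCount = 2;  // placeholder: set to the real GPU count
        for (uint8_t from = 0; from < kDeviceCount; ++from)
            for (uint8_t to = 0; to < kDeviceCount; ++to)
                if (from != to)
                    std::printf("P2P %u <-> %u: %s\n",
                                static_cast<unsigned>(from), static_cast<unsigned>(to),
                                IsP2PAvailable(from, to) ? "available" : "unavailable");
        return 0;
    }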


@@ -3,9 +3,7 @@ project(tgiccl LANGUAGES C CXX CUDA)
 set(TGICCL_HEADERS tgiccl.hpp TgiCclBackend.hpp)
 set(TGICCL_SOURCES TgiCclBackend.cpp)

-find_package(CUDAToolkit REQUIRED)
-add_library(tgiccl SHARED ${TGICCL_HEADERS} ${TGICCL_SOURCES})
+add_library(tgiccl STATIC ${TGICCL_HEADERS} ${TGICCL_SOURCES})
 target_link_libraries(tgiccl PUBLIC tgi_common fmt::fmt spdlog::spdlog CUDA::nvml ${TORCH_LIBRARIES})

 add_executable(test_tgiccl test_tgiccl.cpp)


@@ -5,7 +5,42 @@
 #include "TgiCclBackend.hpp"

+auto fmt::formatter<c10d::ReduceOp>::format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator {
+    string_view name = "unknown";
+    switch (op) {
+        case c10d::ReduceOp::AVG: name = "ReduceOp::AVG"; break;
+        case c10d::ReduceOp::BAND: name = "ReduceOp::BAND"; break;
+        case c10d::ReduceOp::BOR: name = "ReduceOp::BOR"; break;
+        case c10d::ReduceOp::BXOR: name = "ReduceOp::BXOR"; break;
+        case c10d::ReduceOp::MAX: name = "ReduceOp::MAX"; break;
+        case c10d::ReduceOp::MIN: name = "ReduceOp::MIN"; break;
+        case c10d::ReduceOp::PREMUL_SUM: name = "ReduceOp::PREMUL_SUM"; break;
+        case c10d::ReduceOp::PRODUCT: name = "ReduceOp::PRODUCT"; break;
+        case c10d::ReduceOp::SUM: name = "ReduceOp::SUM"; break;
+        case c10d::ReduceOp::UNUSED: name = "ReduceOp::UNUSED"; break;
+    }
+    return formatter<string_view>::format(name, ctx);
+}
+
 void huggingface::tgi::tgiccl::InitTgiCcl()
 {
 }
+
+huggingface::tgi::tgiccl::TgiCclBackend::TgiCclBackend(const int rank, const int size) : Backend(rank, size) {
+    SPDLOG_INFO(FMT_STRING("Creating {} on rank {:d} (world_size={:d})"), getBackendName(), rank, size);
+}
+
+const std::string huggingface::tgi::tgiccl::TgiCclBackend::getBackendName() const {
+    return CCL_BACKEND_NAME;
+}
+
+c10::intrusive_ptr<c10d::Work>
+huggingface::tgi::tgiccl::TgiCclBackend::allreduce(std::vector<at::Tensor> &tensors, const c10d::AllreduceOptions &options) {
+    TORCH_CHECK(options.reduceOp == c10d::ReduceOp::SUM, fmt::format(FMT_STRING("tgiccl only supports ReduceOp::SUM, got {}"), options.reduceOp))
+    tensors[0] += 1;
+    return c10::make_intrusive<c10d::Work>();
+}
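The fmt::formatter<c10d::ReduceOp> specialization defined above is what lets the TORCH_CHECK message interpolate the reduce op directly with "{}". A short sketch of the same mechanism in isolation, assuming the specialization declared in TgiCclBackend.hpp is visible at the call site:

    #include <string>
    #include <fmt/format.h>
    #include "TgiCclBackend.hpp"

    std::string DescribeOp(const c10d::ReduceOp& op)
    {
        // Resolves through the formatter specialization, e.g. yields "ReduceOp::SUM".
        return fmt::format(FMT_STRING("{}"), op);
    }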


@@ -9,17 +9,30 @@
 #include <torch/csrc/distributed/c10d/Backend.hpp>

+template <> struct fmt::formatter<c10d::ReduceOp>: formatter<string_view> {
+    auto format(c10d::ReduceOp op, format_context& ctx) const -> format_context::iterator;
+};
+
 namespace huggingface::tgi::tgiccl
 {
+#define CCL_BACKEND_NAME "tgiccl";
+
     void InitTgiCcl();

+    class TgiCclBackend;
+    class TgiCclBackendWork final: c10d::Work {
+        friend TgiCclBackend;
+    };
+
     class TgiCclBackend final : c10d::Backend {
     public:
-        TgiCclBackend(const int rank, const int size): Backend(rank, size)
-        {
-            SPDLOG_INFO(FMT_STRING("Creating TgiCclBackend on rank {:d} over {:d}"), rank, size);
-        }
+        TgiCclBackend(int rank, int size);
+        const std::string getBackendName() const override;

         c10::intrusive_ptr<c10d::Work> allreduce(std::vector<at::Tensor>&, const c10d::AllreduceOptions&) override;
     };
 }
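The new TgiCclBackendWork stub befriends TgiCclBackend, which suggests the backend alone is meant to drive a work item's private state while callers only observe completion. A generic sketch of that friendship pattern with hypothetical names (not the c10d::Work API, and not code from this commit):

    class Backend;  // hypothetical stand-ins for the classes above

    class WorkItem
    {
        friend Backend;            // only Backend may mark the item finished
        bool completed_ = false;
    public:
        bool isCompleted() const { return completed_; }
    };

    class Backend
    {
    public:
        void finish(WorkItem& w) { w.completed_ = true; }  // legal only via the friend declaration
    };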


@@ -2,10 +2,14 @@
 // Created by morgan on 26/09/24.
 //
+#include <vector>
+#include <torch/torch.h>
 #include "tgiccl.hpp"

 int main() {
-    auto a = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 1);
-    auto b = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 2);
-    auto d = huggingface::tgi::tgiccl::IsNvLinkAvailable(0, 3);
+    auto backend = huggingface::tgi::tgiccl::TgiCclBackend(0, 4);
+    auto tensor = torch::zeros({128});
+    auto tensors = std::vector<torch::Tensor>();
+    tensors.push_back(tensor);
+    backend.allreduce(tensors, c10d::AllreduceOptions());
 }


@@ -11,48 +11,8 @@
 #include "TgiCclBackend.hpp"

-constexpr auto CLL_BACKEND_NAME = "tgiccl";
-
 namespace huggingface::tgi::tgiccl
 {
-    static std::once_flag NVML_INIT_FLAG;
-#define ENSURE_NVML_INIT() std::call_once(NVML_INIT_FLAG, nvmlInit_v2);
-
-    inline std::optional<nvmlDevice_t> GetDeviceByIndex(const size_t index)
-    {
-        ENSURE_NVML_INIT();
-
-        nvmlDevice_t device;
-        if(nvmlDeviceGetHandleByIndex_v2(index, &device) == NVML_SUCCESS)
-            return std::optional{ device };
-
-        return std::nullopt;
-    }
-
-    inline bool IsNvLinkAvailable(const int from, const int to)
-    {
-        ENSURE_NVML_INIT();
-
-        // Get devices
-        const auto devFrom = GetDeviceByIndex(from);
-        const auto devTo = GetDeviceByIndex(to);
-
-        if(!devFrom.has_value())
-            SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), from);
-
-        if(!devTo.has_value())
-            SPDLOG_ERROR(FMT_STRING("Failed to retrieve device at index {:d}"), to);
-
-        // Query link between both
-        nvmlGpuP2PStatus_t status;
-        if(nvmlDeviceGetP2PStatus(devFrom.value(), devTo.value(), NVML_P2P_CAPS_INDEX_NVLINK, &status) != NVML_SUCCESS)
-        {
-            SPDLOG_ERROR(FMT_STRING("Failed to retrieve the p2p status for device {:d} <-> {:d}"), from, to);
-            return false;
-        }
-
-        return status == NVML_P2P_STATUS_OK;
-    }
 }