Skip to content

[wip][base on new api]Add xccl comm split #1793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: xccl/new_api
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 111 additions & 23 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,26 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
return logPrefix_;
}

const std::vector<uint64_t>& ProcessGroupXCCL::groupRanks() const {
if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
static std::vector<uint64_t> globalRanks(size_);
std::iota(globalRanks.begin(), globalRanks.end(), 0);
return globalRanks;
}
return options_->global_ranks_in_group;
}

ProcessGroupXCCL::ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
c10::intrusive_ptr<Store> store,
int rank,
int size)
int size,
c10::intrusive_ptr<Options> options)
: Backend(rank, size),
store_(store),
store_(std::move(store)),
options_(std::move(options)),
xcclCommCounter_(0),
local_id_(process_group_id++) {
this->setGroupUid(options_->group_name);
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
init();
Expand All @@ -345,7 +357,12 @@ ProcessGroupXCCL::ProcessGroupXCCL(
getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
const auto XcclVersion = getXcclVersion();
LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: "
<< "size: " << size << ", global rank: " << rank_;
<< "size: " << size << ", global rank: " << rank_
<< ", USE_HIGH_PRIORITY_STREAM: "
<< options_->is_high_priority_stream
<< ", SPLIT_FROM: " << options_->split_from
<< ", SPLIT_COLOR: " << options_->split_color
<< ", PG Name: " << options_->group_name;

LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: "
<< "XCCL version: " << XcclVersion
Expand All @@ -361,6 +378,38 @@ uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() {
return seqCollective_;
}

// Eagerly requests the XCCL communicator for `device` instead of waiting for
// the first collective to ask for it.
//
// The communicator is keyed by the decimal string of the device index. The
// handle returned by getXCCLComm is not needed here and is discarded; the call
// is made only for its effect on this process group's state.
void ProcessGroupXCCL::eagerConnectSingleDevice(at::Device device) {
  LOG(INFO) << logPrefix() << "Eagerly connecting xccl backend with device "
            << device;
  getXCCLComm(std::to_string(device.index()), device, OpType::ALLREDUCE);
}

// Takes part in a communicator split on `device` with the color
// ONECCL_SPLIT_NOCOLOR, using this group's communicator as the parent.
//
// The parent communicator is obtained from getXCCLComm under the key derived
// from the device index. On the split call this rank passes rank_ as its key
// and options_->config as the configuration, and xcclCommSplitCounter_ is
// incremented afterwards.
//
// Throws DistBackendError if no parent communicator is available.
//
// NOTE(review): ONECCL_SPLIT_NOCOLOR presumably means "this rank is not a
// member of the new group", mirroring NCCL - confirm against the oneCCL docs.
void ProcessGroupXCCL::performNocolorSplit(at::Device device) {
  const auto key = std::to_string(device.index());
  LOG(INFO) << logPrefix() << "Performing nocolor split on backend device "
            << device << ", key " << key << ", i am " << this;
  auto comm = getXCCLComm(key, device, OpType::ALLREDUCE);
  if (comm == nullptr) {
    LOG(ERROR) << logPrefix()
               << "No parent communicator exists for nocolor split";
    // Stop here: the split below dereferences `comm`, which would be
    // undefined behaviour on a null pointer.
    C10_THROW_ERROR(
        DistBackendError, "No parent communicator exists for nocolor split");
  }
  c10::OptionalDeviceGuard gpuGuard(device);
  xcclComm_t comm_new = nullptr;

  // NOTE(review): the status returned by onecclCommSplit and the resulting
  // `comm_new` handle are both unused here - confirm that is intended.
  onecclCommSplit(
      *comm, ONECCL_SPLIT_NOCOLOR, rank_, &comm_new, &options_->config);
  xcclCommSplitCounter_++;
}

// Reports whether this process group has at least one communicator and has
// been marked initialized.
//
// Returns false when devXCCLCommMap_ is empty, otherwise the value of
// initialized_.
bool ProcessGroupXCCL::isInitialized() {
  // Take the lock before touching devXCCLCommMap_: the map is filled under
  // mutex_ when a communicator is created, so reading it unlocked would race
  // with a concurrent insertion.
  std::lock_guard<std::mutex> lock(mutex_);
  if (devXCCLCommMap_.empty()) {
    return false;
  }
  return initialized_;
}

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
at::Device& device,
int rank,
Expand Down Expand Up @@ -447,6 +496,16 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
"the devices are empty ");
}

if (bound_device_id_) {
if (*bound_device_id_ != device) {
LOG(ERROR) << logPrefix() << "Tensor found on device " << device
<< " but backend constrained to " << *bound_device_id_;
C10_THROW_ERROR(
DistBackendError,
"Attempt to perform collective on tensor not on device passed to init_process_group");
}
}

usedDeviceIdxs_.insert(device.index());

{
Expand Down Expand Up @@ -481,26 +540,7 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
rank = p2pRank;
}

c10::impl::VirtualGuardImpl impl(device.type());
c10::Stream stream =
impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);

if (rank_ == 0 || (singleP2POp && p2pRank == 0)) {
onecclGetUniqueId(&xcclID);
}
broadcastUniqueXCCLID(&xcclID, singleP2POp, deviceKey, p2pRank);

xcclComm_t comm = nullptr;
onecclResult_t result = onecclSuccess;
result = onecclSetDevice(device.index());
if (result != onecclSuccess) {
std::cerr << "Failed to set device.\n";
}
result = onecclCommInitRank(&comm, numRanks, xcclID, rank);
if (result != onecclSuccess) {
std::cerr << "Failed to initialize communicator.\n";
}
XCCLComm = std::make_shared<xcclComm_t>(comm);

RECORD_PARAM_COMMS(
0, // seq
Expand All @@ -516,6 +556,42 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
-1, // globalRankStride
size_); // worldSize

if (options_->split_from && !singleP2POp) {
std::lock_guard<std::mutex> lock(options_->split_from->mutex_);
auto& other_comms = options_->split_from->devXCCLCommMap_;
auto dit = other_comms.find(std::to_string(device.index()));
if (dit != other_comms.end()) {
auto& parentComm = dit->second;
if (parentComm != nullptr) {
LOG(INFO) << logPrefix() << "Splitting XCCL communicator from "
<< c10::str(static_cast<void*>(*parentComm));
onecclCommSplit(
*parentComm, options_->split_color, rank, &comm, &options_->config);
xcclCommSplitCounter_++;
}
}
}

if (!comm) {
if (rank_ == 0 || (singleP2POp && p2pRank == 0)) {
onecclGetUniqueId(&xcclID);
}
broadcastUniqueXCCLID(&xcclID, singleP2POp, deviceKey, p2pRank);

onecclResult_t result = onecclSuccess;
result = onecclSetDevice(device.index());
if (result != onecclSuccess) {
std::cerr << "Failed to set device.\n";
}
result = onecclCommInitRankConfig(
&comm, numRanks, xcclID, rank, &options_->config);
if (result != onecclSuccess) {
std::cerr << "Failed to initialize communicator.\n";
}
}

XCCLComm = std::make_shared<xcclComm_t>(comm);

for (const auto i : c10::irange(xcclActiveGroupCounter_)) {
(void)i;
onecclGroupStart();
Expand All @@ -528,19 +604,27 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
// in a map alongside the communicator. Similarly, oneCCLv2 also requires
// retaining the SYCL queue pointer for collective operations, so this change
// will be necessary in oneCCLv2 as well.
bool force_high = getCvarBool(TORCH_XCCL_HIGH_PRIORITY, false);
c10::Stream stream = at::xpu::getStreamFromPool(
options_->is_high_priority_stream || force_high);
std::lock_guard<std::mutex> lock(mutex_);
sycl::queue& q = c10::xpu::XPUStream(stream).queue();
devXCCLCommMap_.emplace(deviceKey, XCCLComm);
xcclStreamsMap_.emplace(
deviceKey, std::make_pair(at::xpu::XPUStream(stream), q));
xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());
initialized_ = true;

LOG(INFO) << logPrefix()
<< "Created XCCL communicator with Key: " << deviceKey;

return XCCLComm;
}

// Returns the current value of xcclCommSplitCounter_, the counter this file
// increments after each onecclCommSplit call made by this process group.
uint64_t ProcessGroupXCCL::getCommSplitCounter() const {
return xcclCommSplitCounter_;
}

void ProcessGroupXCCL::groupStart() {
onecclGroupStart();
++xcclActiveGroupCounter_;
Expand All @@ -551,6 +635,10 @@ void ProcessGroupXCCL::groupEnd() {
--xcclActiveGroupCounter_;
}

// Constructs the XCCL options: initializes the Backend::Options base with
// XCCL_BACKEND_NAME and stores the requested stream-priority flag. All other
// members keep their in-class default initializers.
ProcessGroupXCCL::Options::Options(bool is_high_priority_stream)
: Backend::Options(XCCL_BACKEND_NAME),
is_high_priority_stream(is_high_priority_stream) {}

static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
void ProcessGroupXCCL::startCoalescing() {
coalescedDevice_.set_index(-1);
Expand Down
53 changes: 50 additions & 3 deletions src/xccl/ProcessGroupXCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include <torch/csrc/distributed/c10d/logger.hpp>
namespace c10d {

// Name list handed to getCvarBool (default false) when the stream for a new
// communicator is chosen. A true value requests a high-priority stream even
// if Options::is_high_priority_stream is false.
// NOTE(review): presumably resolved from the process environment - confirm in
// getCvarBool.
static std::vector<std::string> TORCH_XCCL_HIGH_PRIORITY = {
"TORCH_XCCL_HIGH_PRIORITY"};

static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
"TORCH_XCCL_BLOCKING_WAIT",
"XCCL_BLOCKING_WAIT"};
Expand Down Expand Up @@ -105,21 +108,52 @@ class TORCH_API ProcessGroupXCCL : public Backend {
friend class ProcessGroupXCCL;
};

ProcessGroupXCCL(const c10::intrusive_ptr<Store>& store, int rank, int size);
// Per-process-group configuration for ProcessGroupXCCL, layered on top of the
// generic Backend::Options.
struct Options : Backend::Options {
// Defined in ProcessGroupXCCL.cpp; sets the backend name on the base class
// and stores the stream-priority flag.
explicit Options(bool is_high_priority_stream = false);

// Convenience factory returning a ref-counted Options instance.
static c10::intrusive_ptr<Options> create(
bool is_high_priority_stream = false) {
return c10::make_intrusive<Options>(is_high_priority_stream);
}
// When true, the stream for a new communicator is requested from the pool as
// a high-priority stream.
bool is_high_priority_stream;

// oneCCL communicator configuration; passed by pointer to
// onecclCommInitRankConfig and onecclCommSplit.
onecclConfig_t config = ONECCL_CONFIG_INITIALIZER;

// Optional parent process group. When set (and the operation is not a single
// P2P op), communicator creation first tries to split from the parent's
// communicator for the same device index.
std::shared_ptr<ProcessGroupXCCL> split_from;
// Color passed to onecclCommSplit when splitting from `split_from`.
// NOTE(review): the default of ONECCL_SPLIT_NOCOLOR - 1 presumably serves as
// an "unset" sentinel - confirm.
int split_color{ONECCL_SPLIT_NOCOLOR - 1};

// Global ranks that belong to this group; returned by groupRanks() when
// non-empty.
std::vector<uint64_t> global_ranks_in_group;
// Passed to setGroupUid() by the constructor and included in the
// initialization log line.
std::string group_name;
};

ProcessGroupXCCL(
c10::intrusive_ptr<Store> store,
int rank,
int size,
c10::intrusive_ptr<Options> options = Options::create());

C10_DEPRECATED ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
const std::string& groupName)
: ProcessGroupXCCL(store, rank, size) {}
const std::string& groupName,
c10::intrusive_ptr<Options> options = Options::create())
: ProcessGroupXCCL(store, rank, size, std::move(options)) {}

~ProcessGroupXCCL() override;

// Returns a ref-counted handle to the options object held by this process
// group (a copy of the pointer, not of the options themselves).
c10::intrusive_ptr<Options> getOptions() {
return options_;
}

// Returns the backend name as a new std::string built from XCCL_BACKEND_NAME.
const std::string getBackendName() const override {
return std::string(XCCL_BACKEND_NAME);
}

// Always reports true: this backend can create a communicator by splitting an
// existing one through onecclCommSplit.
bool supportsSplitting() const override {
return true;
}

// Always reports true for this backend.
bool supportsCoalescing() const override {
return true;
}
Expand Down Expand Up @@ -373,12 +407,23 @@ class TORCH_API ProcessGroupXCCL : public Backend {

const std::string& logPrefix() const;

const std::vector<uint64_t>& groupRanks() const;

uint64_t getCommSplitCounter() const;

void eagerConnectSingleDevice(at::Device device) override;

void performNocolorSplit(at::Device device);

bool isInitialized();

protected:
std::unordered_map<std::string, std::pair<at::xpu::XPUStream, sycl::queue>>
xcclStreamsMap_;
std::unordered_map<std::string, at::xpu::XPUEvent> xcclEventsMap_;
std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
c10::intrusive_ptr<Store> store_;
const c10::intrusive_ptr<Options> options_;
uint64_t xcclCommCounter_{0};
std::mutex mutex_;
std::set<int> usedDeviceIdxs_;
Expand All @@ -393,6 +438,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
uint64_t seqP2P_{0};
size_t local_id_;
std::string logPrefix_;
uint64_t xcclCommSplitCounter_{0};
bool initialized_{false};
};
} // namespace c10d

Expand Down
Loading