Commit 9dcce05

[XLA:GPU]: Add e2e tests for AllReduce with 8 GPUs
One test covers communication among all eight replicas in a single group, and a second covers communication within four two-replica groups. PiperOrigin-RevId: 763509170
1 parent 825120e commit 9dcce05
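
For orientation, and as an illustration only (not part of the commit): an all-reduce with replica_groups reduces within each group independently, so every replica receives the sum of the inputs of its own group's members. A standalone C++ sketch of this semantics for the grouped layout used by the second test, replica_groups={{0,4},{1,5},{2,6},{3,7}} (names below are hypothetical):

#include <cstdio>
#include <vector>

int main() {
  // Each group reduces independently; every member of a group gets the group sum.
  const std::vector<std::vector<int>> groups = {{0, 4}, {1, 5}, {2, 6}, {3, 7}};
  std::vector<float> inputs = {1, 2, 3, 4, 5, 6, 7, 8};  // one value per replica
  std::vector<float> outputs(inputs.size(), 0.0f);
  for (const auto& group : groups) {
    float sum = 0.0f;
    for (int replica : group) sum += inputs[replica];
    for (int replica : group) outputs[replica] = sum;
  }
  // Prints 6 8 10 12 6 8 10 12: replica 0 pairs with 4 (1+5), 1 with 5 (2+6), ...
  for (float v : outputs) std::printf("%g ", v);
  std::printf("\n");
}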

2 files changed (+156, -2 lines)

xla/tests/BUILD

Lines changed: 1 addition & 1 deletion
@@ -2853,9 +2853,9 @@ xla_test(
         "//xla:literal",
         "//xla:literal_util",
         "//xla:types",
+        "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/utils:hlo_matchers",
-        "//xla/pjrt/gpu:gpu_helpers",
         "//xla/service:computation_placer_hdr",
         "//xla/service:hlo_module_config",
         "//xla/service:hlo_runner_interface",

xla/tests/collective_ops_e2e_test.cc

Lines changed: 155 additions & 1 deletion
@@ -41,7 +41,6 @@ limitations under the License.
 #include "xla/hlo/utils/hlo_matchers.h"
 #include "xla/literal.h"
 #include "xla/literal_util.h"
-#include "xla/pjrt/gpu/gpu_helpers.h"
 #include "xla/service/computation_placer.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/hlo_module_config.h"
@@ -58,6 +57,7 @@ limitations under the License.
 #include "xla/tsl/platform/statusor.h"
 #include "xla/tsl/platform/test.h"
 #include "xla/types.h"
+#include "xla/xla_data.pb.h"
 
 namespace xla {
 namespace {
@@ -3209,6 +3209,19 @@ class AllReduceTest
     : public CollectiveOpsWithFlagsBase,
       public ::testing::WithParamInterface<std::tuple<bool, bool>> {
  public:
+  struct InputsOutputs {
+    std::vector<Literal> inputs;
+    std::vector<Literal> expected_outputs;
+
+    [[nodiscard]] std::vector<std::vector<Literal*>> InputLiteralPtrs() {
+      std::vector<std::vector<Literal*>> result;
+      for (auto& input : inputs) {
+        result.push_back(std::vector<Literal*>{&input});
+      }
+      return result;
+    }
+  };
+
   AllReduceTest()
       : CollectiveOpsWithFlagsBase(std::get<0>(GetParam()),
                                    /*enable_p2p_memcpy=*/false) {}
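
InputLiteralPtrs() exists because ExecuteReplicated (called by the tests below) takes one argument list per replica. A minimal standalone sketch of that shape contract, with std::string as a hypothetical stand-in for xla::Literal:

#include <cassert>
#include <string>
#include <vector>

int main() {
  // One single-element argument list per replica, as InputLiteralPtrs builds:
  // each test module takes exactly one parameter.
  std::vector<std::string> inputs = {"replica0_input", "replica1_input"};
  std::vector<std::vector<std::string*>> args;
  for (auto& input : inputs) args.push_back({&input});
  assert(args.size() == inputs.size());  // one entry per replica
  for (auto& per_replica : args) assert(per_replica.size() == 1);
}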
@@ -3222,6 +3235,53 @@ class AllReduceTest
 
     return opts;
   }
+
+  static absl::StatusOr<InputsOutputs> BuildTestInputsOutputs(
+      const HloModule& module, int64_t num_replicas) {
+    std::vector<Array<float>> inputs;
+    std::vector<Literal> input_literals;
+    const int64_t num_elements =
+        module.entry_computation()->root_instruction()->shape().dimensions()[0];
+    for (int i = 0; i < num_replicas; ++i) {
+      auto& input = inputs.emplace_back(Array<float>({num_elements}));
+      input.FillRandom(1.0f, 10.0f, /*seed=*/i);
+      input_literals.push_back(LiteralUtil::CreateFromArray(input));
+    }
+    std::vector<Array<float>> expected_outputs;
+    std::vector<Literal> expected_output_literals;
+    const HloInstruction* const instr =
+        FindInstruction(&module, HloOpcode::kAllReduce);
+    if (instr == nullptr) {
+      return absl::InvalidArgumentError(
+          "Instruction 'all-reduce' not found in module.");
+    }
+    const std::vector<ReplicaGroup>& replica_groups =
+        instr->device_list().replica_groups();
+    // Map each replica to the member list of its (unique) replica group.
+    std::vector<std::vector<int64_t>> device_to_groups(num_replicas);
+    for (const auto& replica_group : replica_groups) {
+      const auto& replica_ids = replica_group.replica_ids();
+      for (int64_t replica : replica_ids) {
+        CHECK_EQ(device_to_groups[replica].size(), 0);
+        device_to_groups[replica].assign(replica_ids.begin(),
+                                         replica_ids.end());
+      }
+    }
+    for (int i = 0; i < num_replicas; ++i) {
+      auto& expected_output =
+          expected_outputs.emplace_back(Array<float>({num_elements}));
+      // Sum the inputs of every replica in replica i's group.
+      expected_output.Each([&](absl::Span<const int64_t> indices, float* val) {
+        for (const int64_t replica : device_to_groups[i]) {
+          *val += inputs[replica](indices);
+        }
+      });
+      expected_output_literals.push_back(
+          LiteralUtil::CreateFromArray(expected_output));
+    }
+    return InputsOutputs{std::move(input_literals),
+                         std::move(expected_output_literals)};
+  }
 };
 
 TEST_P(AllReduceTest, AsyncAllReduce_F32_2GPUs) {
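
The expected-output logic above can be exercised without XLA types. A minimal sketch of the same replica-to-group mapping and per-group summation, using plain vectors in place of xla::Array and Literal (the name ExpectedAllReduce is hypothetical, not from the commit):

#include <cassert>
#include <cstdint>
#include <vector>

// Replica r's expected all-reduce output is the elementwise sum of the inputs
// of every replica in r's (unique) group, mirroring BuildTestInputsOutputs.
std::vector<std::vector<float>> ExpectedAllReduce(
    const std::vector<std::vector<float>>& inputs,
    const std::vector<std::vector<int64_t>>& replica_groups) {
  std::vector<std::vector<int64_t>> group_of(inputs.size());
  for (const auto& group : replica_groups) {
    for (int64_t r : group) {
      assert(group_of[r].empty());  // each replica is in exactly one group
      group_of[r] = group;
    }
  }
  std::vector<std::vector<float>> outputs(inputs.size());
  for (size_t r = 0; r < inputs.size(); ++r) {
    outputs[r].assign(inputs[r].size(), 0.0f);
    for (int64_t member : group_of[r]) {
      for (size_t i = 0; i < inputs[member].size(); ++i) {
        outputs[r][i] += inputs[member][i];
      }
    }
  }
  return outputs;
}

int main() {
  // Four replicas, groups {0,2} and {1,3}.
  auto out = ExpectedAllReduce({{1, 2}, {3, 4}, {5, 6}, {7, 8}},
                               {{0, 2}, {1, 3}});
  assert(out[0][0] == 6 && out[2][1] == 8);    // {1,2}+{5,6} -> {6,8}
  assert(out[1][0] == 10 && out[3][1] == 12);  // {3,4}+{7,8} -> {10,12}
}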
@@ -3336,6 +3396,100 @@ TEST_P(AllReduceTest, AsyncAllReduce_BF16_2GPUs) {
   EXPECT_TRUE(LiteralTestUtil::Equal(expected_output_literal, results[1]));
 }
 
+TEST_P(AllReduceTest, AsyncAllReduce_8GPUs_AllReplicasOneGroup) {
+  const absl::string_view kModuleStr = R"(
+  HloModule test
+
+  apply_op {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    ROOT apply_op = f32[] add(x, y)
+  }
+
+  ENTRY test_computation {
+    param_0 = f32[65536] parameter(0)
+    ROOT all-reduce = f32[65536] all-reduce(param_0), to_apply=apply_op,
+        replica_groups={{0,1,2,3,4,5,6,7}}
+  }
+  )";
+
+  const int64_t kNumReplicas = 8;
+  if (test_runner().device_count() < kNumReplicas) {
+    GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
+                 << test_runner().device_count() << " available)";
+  }
+
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+  TF_ASSERT_OK_AND_ASSIGN(InputsOutputs test_io,
+                          BuildTestInputsOutputs(*module, kNumReplicas));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      HloTestBase::ExecuteReplicated(std::move(module),
+                                     /*arguments=*/test_io.InputLiteralPtrs(),
+                                     /*num_replicas=*/kNumReplicas,
+                                     /*run_hlo_passes=*/true,
+                                     /*device_assignment=*/nullptr));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; ++i) {
+    // NB: NCCL may accumulate in a different order than the reference
+    // computation, so allow for floating-point rounding differences.
+    ASSERT_TRUE(LiteralTestUtil::Near(test_io.expected_outputs[i], results[i],
+                                      ErrorSpec{1e-4}))
+        << "ExpectedOutput != Result at index " << i;
+  }
+}
+
+TEST_P(AllReduceTest, AsyncAllReduce_8GPUs_2ReplicasPerGroup) {
+  const absl::string_view kModuleStr = R"(
+  HloModule test
+
+  apply_op {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    ROOT apply_op = f32[] add(x, y)
+  }
+
+  ENTRY test_computation {
+    param_0 = f32[65536] parameter(0)
+    ROOT all-reduce = f32[65536] all-reduce(param_0), to_apply=apply_op,
+        replica_groups={{0,4},{1,5},{2,6},{3,7}}
+  }
+  )";
+
+  const int64_t kNumReplicas = 8;
+  if (test_runner().device_count() < kNumReplicas) {
+    GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
+                 << test_runner().device_count() << " available)";
+  }
+
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(InputsOutputs test_io,
+                          BuildTestInputsOutputs(*module, kNumReplicas));
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      HloTestBase::ExecuteReplicated(std::move(module),
+                                     /*arguments=*/test_io.InputLiteralPtrs(),
+                                     /*num_replicas=*/kNumReplicas,
+                                     /*run_hlo_passes=*/true,
+                                     /*device_assignment=*/nullptr));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; ++i) {
+    ASSERT_TRUE(LiteralTestUtil::Equal(test_io.expected_outputs[i], results[i]))
+        << "ExpectedOutput != Result at index " << i;
+  }
+}
+
 INSTANTIATE_TEST_SUITE_P(
     AllReduceTest, AllReduceTest,
     ::testing::Combine(::testing::Bool(), ::testing::Bool()),
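
Why the two tests compare differently: floating-point addition is not associative, and with eight participants in one group NCCL may accumulate in an order that differs from the reference loop in BuildTestInputsOutputs, hence the Near comparison with ErrorSpec{1e-4} in the first test. A standalone illustration (not from the commit):

#include <cstdio>

int main() {
  // Same three values, two association orders, different rounded results.
  float a = 1e8f, b = -1e8f, c = 1.0f;
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));  // prints "1 vs 0"
}

The second test can use the exact Equal comparison: with two replicas per group, each output element is produced by a single addition of two values, and the result of one IEEE-754 addition does not depend on operand order.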
