@@ -41,7 +41,6 @@ limitations under the License.
#include "xla/hlo/utils/hlo_matchers.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
- #include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/service/computation_placer.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/hlo_module_config.h"
@@ -58,6 +57,7 @@ limitations under the License.
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/types.h"
+ #include "xla/xla_data.pb.h"

namespace xla {
namespace {
@@ -3209,6 +3209,19 @@ class AllReduceTest
    : public CollectiveOpsWithFlagsBase,
      public ::testing::WithParamInterface<std::tuple<bool, bool>> {
 public:
+ // Per-replica input literals and the matching expected all-reduce outputs.
+ struct InputsOutputs {
+   std::vector<Literal> inputs;
+   std::vector<Literal> expected_outputs;
+
+   // Wraps each input in its own single-element argument list, the form
+   // that ExecuteReplicated takes.
+   [[nodiscard]] std::vector<std::vector<Literal*>> InputLiteralPtrs() {
+     std::vector<std::vector<Literal*>> result;
+     for (auto& input : inputs) {
+       result.push_back(std::vector<Literal*>{&input});
+     }
+     return result;
+   }
+ };
+
  AllReduceTest()
      : CollectiveOpsWithFlagsBase(std::get<0>(GetParam()),
                                   /*enable_p2p_memcpy=*/false) {}
@@ -3222,6 +3235,53 @@ class AllReduceTest
    return opts;
  }
+
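+ // Builds random per-replica inputs and, from the all-reduce replica groups
+ // in `module`, the expected per-replica outputs (the elementwise sum over
+ // each replica's group).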
+ static absl::StatusOr<InputsOutputs> BuildTestInputsOutputs(
+     const HloModule& module, int64_t num_replicas) {
+   std::vector<Array<float>> inputs;
+   std::vector<Literal> input_literals;
+   const int64_t num_elements =
+       module.entry_computation()->root_instruction()->shape().dimensions()[0];
+   for (int i = 0; i < num_replicas; ++i) {
+     auto& input = inputs.emplace_back(Array<float>({num_elements}));
+     input.FillRandom(1.0f, 10.0f, /*seed=*/i);
+     input_literals.push_back(LiteralUtil::CreateFromArray(input));
+   }
+   std::vector<Array<float>> expected_outputs;
+   std::vector<Literal> expected_output_literals;
+   const HloInstruction* const instr =
+       FindInstruction(&module, HloOpcode::kAllReduce);
+   if (instr == nullptr) {
+     return absl::InvalidArgumentError(
+         "Instruction 'all-reduce' not found in module.");
+   }
+   const std::vector<ReplicaGroup>& replica_groups =
+       instr->device_list().replica_groups();
+   // Map each replica to the members of its replica group. Each replica must
+   // appear in at most one group.
+   std::vector<std::vector<int64_t>> device_to_groups(num_replicas);
+   for (const auto& replica_group : replica_groups) {
+     const auto& replica_ids = replica_group.replica_ids();
+     for (int64_t replica : replica_ids) {
+       CHECK_EQ(device_to_groups[replica].size(), 0);
+       device_to_groups[replica].assign(replica_ids.begin(),
+                                        replica_ids.end());
+     }
+   }
+   for (int i = 0; i < num_replicas; ++i) {
+     auto& expected_output =
+         expected_outputs.emplace_back(Array<float>({num_elements}));
+     // Sum the inputs of every replica in replica i's group.
+     expected_output.Each([&](absl::Span<const int64_t> indices, float* val) {
+       for (const int64_t replica : device_to_groups[i]) {
+         *val += inputs[replica](indices);
+       }
+     });
+     expected_output_literals.push_back(
+         LiteralUtil::CreateFromArray(expected_output));
+   }
+   return InputsOutputs{std::move(input_literals),
+                        std::move(expected_output_literals)};
+ }
};

TEST_P(AllReduceTest, AsyncAllReduce_F32_2GPUs) {
@@ -3336,6 +3396,100 @@ TEST_P(AllReduceTest, AsyncAllReduce_BF16_2GPUs) {
  EXPECT_TRUE(LiteralTestUtil::Equal(expected_output_literal, results[1]));
}

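+ // With a single replica group {0,...,7}, every replica should receive the
+ // elementwise sum of all eight replicas' inputs.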
+ TEST_P(AllReduceTest, AsyncAllReduce_8GPUs_AllReplicasOneGroup) {
+   const absl::string_view kModuleStr = R"(
+   HloModule test
+
+   apply_op {
+     x = f32[] parameter(0)
+     y = f32[] parameter(1)
+     ROOT apply_op = f32[] add(x, y)
+   }
+
+   ENTRY test_computation {
+     param_0 = f32[65536] parameter(0)
+     ROOT all-reduce = f32[65536] all-reduce(param_0), to_apply=apply_op,
+         replica_groups={{0,1,2,3,4,5,6,7}}
+   }
+   )";
+
+   const int64_t kNumReplicas = 8;
+   if (test_runner().device_count() < kNumReplicas) {
+     GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
+                  << test_runner().device_count() << " available)";
+   }
+
+   HloModuleConfig config =
+       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+
+   TF_ASSERT_OK_AND_ASSIGN(auto module,
+                           ParseAndReturnVerifiedModule(kModuleStr, config));
+   TF_ASSERT_OK_AND_ASSIGN(InputsOutputs test_io,
+                           BuildTestInputsOutputs(*module, kNumReplicas));
+
+   TF_ASSERT_OK_AND_ASSIGN(
+       std::vector<Literal> results,
+       HloTestBase::ExecuteReplicated(std::move(module),
+                                      /*arguments=*/test_io.InputLiteralPtrs(),
+                                      /*num_replicas=*/kNumReplicas,
+                                      /*run_hlo_passes=*/true,
+                                      /*device_assignment=*/nullptr));
+   ASSERT_EQ(results.size(), kNumReplicas);
+   for (int i = 0; i < kNumReplicas; ++i) {
+     // NB: NCCL may accumulate in a different order than the reference
+     // calculation, so results can differ by floating-point rounding;
+     // compare with a tolerance.
+     ASSERT_TRUE(LiteralTestUtil::Near(test_io.expected_outputs[i], results[i],
+                                       ErrorSpec{1e-4}))
+         << "ExpectedOutput != Result at index " << i;
+   }
+ }
+
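+ // Each pair {i, i+4} forms its own group, so replicas i and i+4 should both
+ // receive input_i + input_{i+4}. A single pairwise addition per element has
+ // no ordering ambiguity, which presumably makes the exact Equal check below
+ // safe.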
+ TEST_P(AllReduceTest, AsyncAllReduce_8GPUs_2ReplicasPerGroup) {
+   const absl::string_view kModuleStr = R"(
+   HloModule test
+
+   apply_op {
+     x = f32[] parameter(0)
+     y = f32[] parameter(1)
+     ROOT apply_op = f32[] add(x, y)
+   }
+
+   ENTRY test_computation {
+     param_0 = f32[65536] parameter(0)
+     ROOT all-reduce = f32[65536] all-reduce(param_0), to_apply=apply_op,
+         replica_groups={{0,4},{1,5},{2,6},{3,7}}
+   }
+   )";
+
+   const int64_t kNumReplicas = 8;
+   if (test_runner().device_count() < kNumReplicas) {
+     GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
+                  << test_runner().device_count() << " available)";
+   }
+
+   HloModuleConfig config =
+       GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+
+   TF_ASSERT_OK_AND_ASSIGN(auto module,
+                           ParseAndReturnVerifiedModule(kModuleStr, config));
+
+   TF_ASSERT_OK_AND_ASSIGN(InputsOutputs test_io,
+                           BuildTestInputsOutputs(*module, kNumReplicas));
+
+   TF_ASSERT_OK_AND_ASSIGN(
+       std::vector<Literal> results,
+       HloTestBase::ExecuteReplicated(std::move(module),
+                                      /*arguments=*/test_io.InputLiteralPtrs(),
+                                      /*num_replicas=*/kNumReplicas,
+                                      /*run_hlo_passes=*/true,
+                                      /*device_assignment=*/nullptr));
+   ASSERT_EQ(results.size(), kNumReplicas);
+   for (int i = 0; i < kNumReplicas; ++i) {
+     ASSERT_TRUE(
+         LiteralTestUtil::Equal(test_io.expected_outputs[i], results[i]))
+         << "ExpectedOutput != Result at index " << i;
+   }
+ }
+
INSTANTIATE_TEST_SUITE_P(
    AllReduceTest, AllReduceTest,
    ::testing::Combine(::testing::Bool(), ::testing::Bool()),