Skip to content

Commit da63f29

Browse files
[XLA:GPU] Add support for all-to-all to perf table gen.
PiperOrigin-RevId: 767489571
1 parent 1c8ee2e commit da63f29

8 files changed

+175
-38
lines changed

xla/service/gpu/model/collective_interpolator.cc

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,51 @@ std::unique_ptr<HloModule> AllGatherModule(
249249
return module;
250250
}
251251

252+
std::unique_ptr<HloModule> AllToAllModule(
253+
const HloInstructionProfile& profile) {
254+
HloModuleConfig config;
255+
auto module = std::make_unique<HloModule>("m", config);
256+
auto shape = Shape::FromProto(profile.instruction().shape());
257+
if (!shape.ok()) {
258+
VLOG(1) << "Cannot parse shape: " << profile.DebugString();
259+
return nullptr;
260+
}
261+
262+
HloComputation::Builder entry_builder("entry");
263+
CollectiveDeviceList collective_device_list(
264+
IotaReplicaGroupList::FromProto(profile.instruction()
265+
.collective_device_list()
266+
.iota_replica_group_list()));
267+
268+
HloInstruction* p0 = entry_builder.AddInstruction(
269+
HloInstruction::CreateParameter(0, *shape, "p0"));
270+
entry_builder.AddInstruction(HloInstruction::CreateAllToAll(
271+
*shape, {p0}, collective_device_list,
272+
profile.instruction().constrain_layout(),
273+
profile.instruction().channel_id(),
274+
profile.instruction().use_global_device_ids()));
275+
module->AddEntryComputation(entry_builder.Build());
276+
return module;
277+
}
278+
279+
std::optional<CollectiveDeviceList> CanonicalDeviceList(
280+
const HloCollectiveInstruction& instr) {
281+
if (instr.device_list().iota_replica_group_list().has_value()) {
282+
return instr.device_list();
283+
}
284+
auto num_groups_and_devices = GetReplicaGroupCountAndSize(&instr);
285+
if (!num_groups_and_devices.ok() || !num_groups_and_devices->has_value()) {
286+
VLOG(1) << "Failed to determine a number of devices participating in "
287+
"the collective: "
288+
<< instr.ToString();
289+
return std::nullopt;
290+
}
291+
292+
IotaReplicaGroupList iota((*num_groups_and_devices)->first,
293+
(*num_groups_and_devices)->second);
294+
return CollectiveDeviceList(iota);
295+
}
296+
252297
HloOpcode AsyncToSyncOpcode(const HloCollectiveInstruction& instr) {
253298
HloOpcode opcode = instr.opcode();
254299
switch (opcode) {
@@ -294,6 +339,17 @@ int64_t GetBytesTransferred(const HloInstruction& instr,
294339
return adhoc.BytesTransferred(instr);
295340
}
296341

342+
bool RequiresAccumulation(HloOpcode opcode) {
343+
switch (opcode) {
344+
case HloOpcode::kAllReduceStart:
345+
case HloOpcode::kAllReduce:
346+
case HloOpcode::kReduceScatter:
347+
return true;
348+
default:
349+
return false;
350+
}
351+
}
352+
297353
absl::StatusOr<std::unique_ptr<
298354
absl::flat_hash_map<CollectiveInterpolator::ExactInterpolatorKey,
299355
std::unique_ptr<InterpolatorBase<int64_t, 1>>>>>
@@ -311,7 +367,9 @@ ConstructExactInterpolators(int num_devices_per_host,
311367
CollectiveInterpolator::ExactInterpolatorKey exact_key{
312368
/*opcode=*/spec.opcode,
313369
/*device_list=*/spec.device_list,
314-
/*data_type=*/spec.data_type,
370+
/*data_type=*/
371+
RequiresAccumulation(spec.opcode) ? std::make_optional(spec.data_type)
372+
: std::nullopt,
315373
};
316374
auto exact_it = exact_interpolators->find(exact_key);
317375
if (exact_it == exact_interpolators->end()) {
@@ -429,17 +487,6 @@ ConstructFallbackNNInterpolators(int num_devices_per_host,
429487
return fallback_interpolators;
430488
}
431489

432-
bool RequiresAccumulation(const HloCollectiveInstruction& instr) {
433-
switch (instr.opcode()) {
434-
case HloOpcode::kAllReduceStart:
435-
case HloOpcode::kAllReduce:
436-
case HloOpcode::kReduceScatter:
437-
return true;
438-
default:
439-
return false;
440-
}
441-
}
442-
443490
} // namespace
444491

445492
// We can get rid of `analysis` being nullptr once we get rid of stats
@@ -488,21 +535,23 @@ std::optional<absl::Duration> CollectiveInterpolator::EstimatedRuntime(
488535
int64_t bytes_transferred =
489536
GetBytesTransferred(instr, device_info_, analysis_);
490537

491-
ExactInterpolatorKey exact_key{
492-
/*opcode=*/instr.opcode(),
493-
/*device_list=*/instr.device_list(),
494-
/*data_type=*/
495-
RequiresAccumulation(instr)
496-
? std::make_optional(instr.shape().element_type())
497-
: std::nullopt,
498-
};
538+
std::optional<CollectiveDeviceList> devices = CanonicalDeviceList(instr);
539+
if (devices.has_value()) {
540+
ExactInterpolatorKey exact_key{
541+
/*opcode=*/instr.opcode(),
542+
/*device_list=*/*devices,
543+
/*data_type=*/
544+
RequiresAccumulation(instr.opcode())
545+
? std::make_optional(instr.shape().element_type())
546+
: std::nullopt,
547+
};
499548

500-
if (exact_interpolators_->contains(exact_key)) {
501-
std::array<int64_t, 1> point({bytes_transferred});
502-
return absl::Seconds(1.0 * bytes_transferred /
503-
exact_interpolators_->at(exact_key)->Eval(point));
549+
if (exact_interpolators_->contains(exact_key)) {
550+
std::array<int64_t, 1> point({bytes_transferred});
551+
return absl::Seconds(1.0 * bytes_transferred /
552+
exact_interpolators_->at(exact_key)->Eval(point));
553+
}
504554
}
505-
506555
// Fallback interpolation.
507556
auto comm = CommunicationType(num_devices_per_host_, instr,
508557
device_info_.gpu_compute_capability());
@@ -537,6 +586,8 @@ std::optional<absl::Duration> CollectiveInterpolator::EstimatedRuntime(
537586
case HloOpcode::kAllGather:
538587
case HloOpcode::kAllGatherStart:
539588
return AllGatherModule(profile);
589+
case HloOpcode::kAllToAll:
590+
return AllToAllModule(profile);
540591
default:
541592
LOG(FATAL) << "Unsupported profile instruction: "
542593
<< profile.DebugString();

xla/service/gpu/model/collective_interpolator_test.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
101101
switch (opcode) {
102102
case HloOpcode::kAllReduce:
103103
case HloOpcode::kAllReduceStart:
104+
case HloOpcode::kAllToAll:
104105
device_list = CollectiveDeviceList(CommToDeviceList(comm, num_hosts));
105106
shape = ShapeUtil::MakeShape(PrimitiveType::F32, {tensor_size / 4});
106107
break;
@@ -419,6 +420,27 @@ class CollectiveInterpolationTest : public TestWithParam<ParametrizedTestCase> {
419420
/*num_nodes=*/4,
420421
/*network_througput_bytes=*/2 * 2048,
421422
},
423+
{
424+
/*opcode=*/HloOpcode::kAllToAll,
425+
/*comm=*/GPUCommunicationType::SINGLE_HOST,
426+
/*tensor_size=*/1024,
427+
/*num_nodes=*/1,
428+
/*network_througput_bytes=*/1024,
429+
},
430+
{
431+
/*opcode=*/HloOpcode::kAllToAll,
432+
/*comm=*/GPUCommunicationType::RAIL_ALIGNED,
433+
/*tensor_size=*/1024,
434+
/*num_nodes=*/2,
435+
/*network_througput_bytes=*/2048,
436+
},
437+
{
438+
/*opcode=*/HloOpcode::kAllToAll,
439+
/*comm=*/GPUCommunicationType::NON_RAIL_ALIGNED,
440+
/*tensor_size=*/1024,
441+
/*num_nodes=*/2,
442+
/*network_througput_bytes=*/4096,
443+
},
422444
};
423445
};
424446

@@ -960,6 +982,39 @@ INSTANTIATE_TEST_SUITE_P(
960982
},
961983
/*expected_duration=*/absl::Milliseconds(625),
962984
},
985+
{
986+
/*test_name=*/"A2A_rail_aligned_exact_match",
987+
{
988+
/*opcode=*/HloOpcode::kAllToAll,
989+
/*comm=*/
990+
GPUCommunicationType::RAIL_ALIGNED,
991+
/*tensor_size=*/1024,
992+
/*num_nodes=*/2,
993+
},
994+
/*expected_duration=*/absl::Milliseconds(500),
995+
},
996+
{
997+
/*test_name=*/"A2A_nonrail_aligned_exact_match",
998+
{
999+
/*opcode=*/HloOpcode::kAllToAll,
1000+
/*comm=*/
1001+
GPUCommunicationType::NON_RAIL_ALIGNED,
1002+
/*tensor_size=*/1024,
1003+
/*num_nodes=*/2,
1004+
},
1005+
/*expected_duration=*/absl::Milliseconds(250),
1006+
},
1007+
{
1008+
/*test_name=*/"A2A_single_host_exact_match",
1009+
{
1010+
/*opcode=*/HloOpcode::kAllToAll,
1011+
/*comm=*/
1012+
GPUCommunicationType::SINGLE_HOST,
1013+
/*tensor_size=*/1024,
1014+
/*num_nodes=*/1,
1015+
},
1016+
/*expected_duration=*/absl::Seconds(1),
1017+
},
9631018
}),
9641019
[](const TestParamInfo<CollectiveInterpolationTest::ParamType>& info) {
9651020
return info.param.test_name;

xla/service/gpu/model/gpu_hlo_cost_analysis.cc

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -533,13 +533,14 @@ absl::Status GpuHloCostAnalysis::HandleAllGatherStart(
533533

534534
absl::Status GpuHloCostAnalysis::HandleAsyncStart(const HloInstruction* hlo) {
535535
auto* async_start = DynCast<HloAsyncStartInstruction>(hlo);
536-
if (async_start->async_wrapped_opcode() != HloOpcode::kReduceScatter) {
537-
VLOG(2) << "Only Reduce Scatter is supported.";
538-
return absl::OkStatus();
539-
}
540-
541536
TF_RETURN_IF_ERROR(hlo->async_wrapped_instruction()->Accept(this));
542-
return HandleReduceScatter(async_start->async_wrapped_instruction());
537+
if (async_start->async_wrapped_opcode() == HloOpcode::kReduceScatter) {
538+
return HandleReduceScatter(async_start->async_wrapped_instruction());
539+
}
540+
if (async_start->async_wrapped_opcode() == HloOpcode::kAllToAll) {
541+
return HandleAllToAll(async_start->async_wrapped_instruction());
542+
}
543+
return absl::OkStatus();
543544
}
544545

545546
absl::Status GpuHloCostAnalysis::HandleReduceScatter(
@@ -563,6 +564,12 @@ absl::Status GpuHloCostAnalysis::HandleReduceScatter(
563564
return absl::OkStatus();
564565
}
565566

567+
absl::Status GpuHloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) {
568+
int64_t bytes_transferred = ShapeSize(hlo->shape(), options_.shape_size);
569+
current_properties_[kCollBytesTransferred] = bytes_transferred;
570+
return absl::OkStatus();
571+
}
572+
566573
absl::Status GpuHloCostAnalysis::HandleElementwiseOp(
567574
const HloInstruction* hlo) {
568575
current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(hlo);

xla/service/gpu/model/gpu_hlo_cost_analysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class GpuHloCostAnalysis : public HloCostAnalysis {
7878
absl::Status HandleAllGatherStart(const HloInstruction* hlo) override;
7979
absl::Status HandleAsyncStart(const HloInstruction* hlo) override;
8080
absl::Status HandleReduceScatter(const HloInstruction* hlo) override;
81+
absl::Status HandleAllToAll(const HloInstruction* hlo) override;
8182

8283
// Estimate the total size of IR accounting for both duplication
8384
// of producer code by consumer and the total number of basic blocks.

xla/tools/collective_perf_table_gen.cc

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,14 @@ int64_t GetInputDim(CollectivePerfTableGen::CollectiveType type,
121121
CHECK_EQ(tensor_size_bytes % kBytesPerElem, 0);
122122
switch (type) {
123123
case CollectivePerfTableGen::CollectiveType::ALL_REDUCE:
124+
case CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER:
125+
case CollectivePerfTableGen::CollectiveType::ALL_TO_ALL:
124126
dim_size = tensor_size_bytes / kBytesPerElem;
125127
break;
126128
case CollectivePerfTableGen::CollectiveType::ALL_GATHER:
127129
dim_size = tensor_size_bytes /
128130
(kBytesPerElem * replica_groups.num_devices_per_group());
129131
break;
130-
case CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER:
131-
dim_size = tensor_size_bytes / kBytesPerElem;
132-
break;
133132
default:
134133
LOG(FATAL) << "Unsupported collective type.";
135134
}
@@ -144,6 +143,7 @@ int64_t GetOutputDim(CollectivePerfTableGen::CollectiveType type,
144143
switch (type) {
145144
case CollectivePerfTableGen::CollectiveType::ALL_REDUCE:
146145
case CollectivePerfTableGen::CollectiveType::ALL_GATHER:
146+
case CollectivePerfTableGen::CollectiveType::ALL_TO_ALL:
147147
dim_size = tensor_size_bytes / kBytesPerElem;
148148
break;
149149
case CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER:
@@ -215,6 +215,19 @@ std::string GetHlo(CollectivePerfTableGen::CollectiveType type,
215215
"f32", input_dim, output_dim,
216216
replica_groups.ToString());
217217
break;
218+
case CollectivePerfTableGen::CollectiveType::ALL_TO_ALL:
219+
hlo = absl::Substitute(R"(
220+
HloModule m
221+
222+
ENTRY e {
223+
p0 = $0[$1] parameter(0)
224+
ROOT _ = $0[$2] all-to-all(p0), replica_groups=$3, channel_id=1,
225+
dimensions={0}
226+
}
227+
)",
228+
"f32", input_dim, output_dim,
229+
replica_groups.ToString());
230+
break;
218231
default:
219232
LOG(FATAL) << "Unsupported collective type.";
220233
}

xla/tools/collective_perf_table_gen.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class CollectivePerfTableGen {
5353
ALL_REDUCE,
5454
ALL_GATHER,
5555
REDUCE_SCATTER,
56+
ALL_TO_ALL,
5657
};
5758

5859
struct Config {
@@ -64,6 +65,7 @@ class CollectivePerfTableGen {
6465
CollectiveType::ALL_REDUCE,
6566
CollectiveType::ALL_GATHER,
6667
CollectiveType::REDUCE_SCATTER,
68+
CollectiveType::ALL_TO_ALL,
6769
};
6870
std::vector<std::string> replica_groups_list;
6971

xla/tools/collective_perf_table_gen_main.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ std::vector<CollectivePerfTableGen::CollectiveType> ParseCollectives(
106106
types.push_back(CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER);
107107
continue;
108108
}
109+
if (token == "ALL_TO_ALL") {
110+
types.push_back(CollectivePerfTableGen::CollectiveType::ALL_TO_ALL);
111+
continue;
112+
}
109113
}
110114
CHECK_GT(types.size(), 0);
111115
return types;
@@ -160,10 +164,12 @@ int main(int argc, char* argv[]) {
160164
int32_t num_nodes = 1;
161165
int32_t num_devices_per_host = 8;
162166
int32_t task_id = 0;
163-
std::string collectives_unparsed = "ALL_REDUCE,ALL_GATHER,REDUCE_SCATTER";
167+
std::string collectives_unparsed =
168+
"ALL_REDUCE,ALL_GATHER,REDUCE_SCATTER,ALL_TO_ALL";
164169
std::string tensor_size_bytes_spec_unparsed =
165170
"start=1024,stop=2147483648,factor=2";
166-
std::string collective_devices_spec_unparsed;
171+
std::string collective_devices_spec_unparsed =
172+
"[1,8]<=[8];[2,4]<=[8];[4,2]<=[8]";
167173
std::string coordinator_address = std::string(kDefaultCoordinatorAddress);
168174
std::string output = std::string(CollectivePerfTableGen::Config::kStdout);
169175
std::string merge_path;
@@ -179,7 +185,8 @@ int main(int argc, char* argv[]) {
179185
"across the distributed system you run it on."),
180186
tsl::Flag("collectives", &collectives_unparsed,
181187
"Comma separated list of collectives to generate perf table "
182-
"for. Allowed values: ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER."),
188+
"for. Allowed values: ALL_REDUCE, ALL_GATHER, REDUCE_SCATTER, "
189+
"ALL_TO_ALL."),
183190
tsl::Flag("tensor_size_bytes_spec", &tensor_size_bytes_spec_unparsed,
184191
"Spec for a search sweep over transfer sizes. Format example: "
185192
"start=1,stop=8,factor=2 generates {1,2,4,8}."),

xla/tools/collective_perf_table_gen_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) {
8585
CollectivePerfTableGen::CollectiveType::ALL_REDUCE,
8686
CollectivePerfTableGen::CollectiveType::ALL_GATHER,
8787
CollectivePerfTableGen::CollectiveType::REDUCE_SCATTER,
88+
CollectivePerfTableGen::CollectiveType::ALL_TO_ALL,
8889
};
8990
cfg_.replica_groups_list.emplace_back("[1,1]<=[1]");
9091
CollectivePerfTableGen::StepSpec spec{
@@ -100,7 +101,7 @@ TEST_F(CollectivePerfTableGenTest, FactorStepGeneratesConfigs) {
100101

101102
DeviceHloInstructionProfiles profiles = gen->ComputeTable();
102103
EXPECT_EQ(profiles.entries_size(), 1);
103-
EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 12);
104+
EXPECT_EQ(profiles.entries().begin()->second.entries_size(), 16);
104105
}
105106

106107
TEST_F(CollectivePerfTableGenTest, HappyPathWorks) {

0 commit comments

Comments
 (0)