Skip to content

Commit 5a5e232

Browse files
thcmbs authored and Google-ML-Automation committed
[XLA:GPU] Add triton support test for recv & friends
PiperOrigin-RevId: 763370346
1 parent f5d441d commit 5a5e232

File tree

2 files changed

+65 -8 lines changed

xla/backends/gpu/codegen/triton/support.cc

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -656,13 +656,9 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) {
     case HloOpcode::kGather:
     case HloOpcode::kPad:
     case HloOpcode::kRaggedDot:
-    case HloOpcode::kRecv:
-    case HloOpcode::kRecvDone:
     case HloOpcode::kReduceWindow:
     case HloOpcode::kScatter:
     case HloOpcode::kSelectAndScatter:
-    case HloOpcode::kSend:
-    case HloOpcode::kSendDone:
     case HloOpcode::kSetDimensionSize:
     case HloOpcode::kSort:
       return true;

xla/backends/gpu/codegen/triton/support_test.cc

Lines changed: 65 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2907,6 +2907,69 @@ INSTANTIATE_TEST_SUITE_P(
                        ::testing::ValuesIn(AllDevicesToTest())),
     TritonSupportTestTypeAndDeviceToString);
 
+using RecvOpsTest = TritonSupportTestWithTypeAndDeviceParam;
+
+TEST_P(RecvOpsTest, RecvAndRecvDone) {
+  auto [data_type, cc] = GetParam();
+  const std::string kHloTestTemplate = R"(
+    ENTRY triton_computation {
+      token0 = token[] after-all()
+      recv_op = ($0[10,20], u32[], token[]) recv(token0), channel_id=15
+      recv_done_op = ($0[10,20], token[]) recv-done(recv_op), channel_id=15
+      ROOT result = $0[10,20] get-tuple-element(recv_done_op), index=0
+    })";
+  TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti_recv,
+                          ParseTemplateAndGetInstruction(
+                              kHloTestTemplate, data_type, HloOpcode::kRecv));
+  RunSupportTest(std::move(ti_recv), /*output_tile_sizes=*/{1, 1}, cc);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti_recv_done,
+      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
+                                     HloOpcode::kRecvDone));
+  RunSupportTest(std::move(ti_recv_done), /*output_tile_sizes=*/{1, 1}, cc);
+}
+
+constexpr std::array kTestedOpsRecv = {HloOpcode::kRecv, HloOpcode::kRecvDone};
+
+INSTANTIATE_TEST_SUITE_P(
+    RecvOpsSuite, RecvOpsTest,
+    ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()),
+                       ::testing::ValuesIn(AllDevicesToTest())),
+    TritonSupportTestTypeAndDeviceToString);
+
+using SendOpsTest = TritonSupportTestWithTypeAndDeviceParam;
+
+TEST_P(SendOpsTest, SendAndSendDone) {
+  auto [data_type, cc] = GetParam();
+  const std::string kHloTestTemplate = R"(
+    ENTRY triton_computation {
+      data = $0[10] parameter(0)
+      token0 = token[] after-all()
+      send_op = ($0[10], u32[], token[]) send(data, token0), channel_id=77
+      ROOT send_done_op = token[] send-done(send_op), channel_id=77
+    })";
+
+  TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti_send,
+                          ParseTemplateAndGetInstruction(
+                              kHloTestTemplate, data_type, HloOpcode::kSend));
+  RunSupportTest(std::move(ti_send), /*output_tile_sizes=*/{}, cc);
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      TestedInstruction ti_send_done,
+      ParseTemplateAndGetInstruction(kHloTestTemplate, data_type,
+                                     HloOpcode::kSendDone));
+  RunSupportTest(std::move(ti_send_done), /*output_tile_sizes=*/{}, cc);
+}
+
+constexpr std::array kTestedOpsSend = {HloOpcode::kSend, HloOpcode::kSendDone};
+
+INSTANTIATE_TEST_SUITE_P(
+    SendOpsSuite, SendOpsTest,
+    ::testing::Combine(::testing::ValuesIn(AllXlaDataTypes()),
+                       ::testing::ValuesIn(AllDevicesToTest())),
+    TritonSupportTestTypeAndDeviceToString);
+
 class StochasticConvertTest
     : public TritonSupportTest,
       public ::testing::WithParamInterface<
@@ -2996,13 +3059,9 @@ constexpr std::array kUnsupportedOps = {
     HloOpcode::kGather,
     HloOpcode::kPad,
     HloOpcode::kRaggedDot,
-    HloOpcode::kRecv,
-    HloOpcode::kRecvDone,
     HloOpcode::kReduceWindow,
     HloOpcode::kScatter,
     HloOpcode::kSelectAndScatter,
-    HloOpcode::kSend,
-    HloOpcode::kSendDone,
     HloOpcode::kSetDimensionSize,
     HloOpcode::kSort,
     // go/keep-sorted end
@@ -3030,6 +3089,8 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
   ret.insert(kTestedOpsIota.begin(), kTestedOpsIota.end());
   ret.insert(kTestedOpsRng.begin(), kTestedOpsRng.end());
   ret.insert(kTestedOpsCopy.begin(), kTestedOpsCopy.end());
+  ret.insert(kTestedOpsRecv.begin(), kTestedOpsRecv.end());
+  ret.insert(kTestedOpsSend.begin(), kTestedOpsSend.end());
 
   ret.emplace(HloOpcode::kAfterAll);
   ret.emplace(HloOpcode::kAddDependency);

0 commit comments

Comments
 (0)