@@ -2907,6 +2907,69 @@ INSTANTIATE_TEST_SUITE_P(
2907
2907
::testing::ValuesIn(AllDevicesToTest())),
2908
2908
TritonSupportTestTypeAndDeviceToString);
2909
2909
2910
+ using RecvOpsTest = TritonSupportTestWithTypeAndDeviceParam;
2911
+
2912
+ TEST_P (RecvOpsTest, RecvAndRecvDone) {
2913
+ auto [data_type, cc] = GetParam ();
2914
+ const std::string kHloTestTemplate = R"(
2915
+ ENTRY triton_computation {
2916
+ token0 = token[] after-all()
2917
+ recv_op = ($0[10,20], u32[], token[]) recv(token0), channel_id=15
2918
+ recv_done_op = ($0[10,20], token[]) recv-done(recv_op), channel_id=15
2919
+ ROOT result = $0[10,20] get-tuple-element(recv_done_op), index=0
2920
+ })" ;
2921
+ TF_ASSERT_OK_AND_ASSIGN (TestedInstruction ti_recv,
2922
+ ParseTemplateAndGetInstruction (
2923
+ kHloTestTemplate , data_type, HloOpcode::kRecv ));
2924
+ RunSupportTest (std::move (ti_recv), /* output_tile_sizes=*/ {1 , 1 }, cc);
2925
+
2926
+ TF_ASSERT_OK_AND_ASSIGN (
2927
+ TestedInstruction ti_recv_done,
2928
+ ParseTemplateAndGetInstruction (kHloTestTemplate , data_type,
2929
+ HloOpcode::kRecvDone ));
2930
+ RunSupportTest (std::move (ti_recv_done), /* output_tile_sizes=*/ {1 , 1 }, cc);
2931
+ }
2932
+
2933
+ constexpr std::array kTestedOpsRecv = {HloOpcode::kRecv , HloOpcode::kRecvDone };
2934
+
2935
+ INSTANTIATE_TEST_SUITE_P (
2936
+ RecvOpsSuite, RecvOpsTest,
2937
+ ::testing::Combine (::testing::ValuesIn(AllXlaDataTypes()),
2938
+ ::testing::ValuesIn(AllDevicesToTest())),
2939
+ TritonSupportTestTypeAndDeviceToString);
2940
+
2941
+ using SendOpsTest = TritonSupportTestWithTypeAndDeviceParam;
2942
+
2943
+ TEST_P (SendOpsTest, SendAndSendDone) {
2944
+ auto [data_type, cc] = GetParam ();
2945
+ const std::string kHloTestTemplate = R"(
2946
+ ENTRY triton_computation {
2947
+ data = $0[10] parameter(0)
2948
+ token0 = token[] after-all()
2949
+ send_op = ($0[10], u32[], token[]) send(data, token0), channel_id=77
2950
+ ROOT send_done_op = token[] send-done(send_op), channel_id=77
2951
+ })" ;
2952
+
2953
+ TF_ASSERT_OK_AND_ASSIGN (TestedInstruction ti_send,
2954
+ ParseTemplateAndGetInstruction (
2955
+ kHloTestTemplate , data_type, HloOpcode::kSend ));
2956
+ RunSupportTest (std::move (ti_send), /* output_tile_sizes=*/ {}, cc);
2957
+
2958
+ TF_ASSERT_OK_AND_ASSIGN (
2959
+ TestedInstruction ti_send_done,
2960
+ ParseTemplateAndGetInstruction (kHloTestTemplate , data_type,
2961
+ HloOpcode::kSendDone ));
2962
+ RunSupportTest (std::move (ti_send_done), /* output_tile_sizes=*/ {}, cc);
2963
+ }
2964
+
2965
+ constexpr std::array kTestedOpsSend = {HloOpcode::kSend , HloOpcode::kSendDone };
2966
+
2967
+ INSTANTIATE_TEST_SUITE_P (
2968
+ SendOpsSuite, SendOpsTest,
2969
+ ::testing::Combine (::testing::ValuesIn(AllXlaDataTypes()),
2970
+ ::testing::ValuesIn(AllDevicesToTest())),
2971
+ TritonSupportTestTypeAndDeviceToString);
2972
+
2910
2973
class StochasticConvertTest
2911
2974
: public TritonSupportTest,
2912
2975
public ::testing::WithParamInterface<
@@ -2996,13 +3059,9 @@ constexpr std::array kUnsupportedOps = {
2996
3059
HloOpcode::kGather ,
2997
3060
HloOpcode::kPad ,
2998
3061
HloOpcode::kRaggedDot ,
2999
- HloOpcode::kRecv ,
3000
- HloOpcode::kRecvDone ,
3001
3062
HloOpcode::kReduceWindow ,
3002
3063
HloOpcode::kScatter ,
3003
3064
HloOpcode::kSelectAndScatter ,
3004
- HloOpcode::kSend ,
3005
- HloOpcode::kSendDone ,
3006
3065
HloOpcode::kSetDimensionSize ,
3007
3066
HloOpcode::kSort ,
3008
3067
// go/keep-sorted end
@@ -3030,6 +3089,8 @@ absl::flat_hash_set<HloOpcode> AllTestedOpcodes() {
3030
3089
ret.insert (kTestedOpsIota .begin (), kTestedOpsIota .end ());
3031
3090
ret.insert (kTestedOpsRng .begin (), kTestedOpsRng .end ());
3032
3091
ret.insert (kTestedOpsCopy .begin (), kTestedOpsCopy .end ());
3092
+ ret.insert (kTestedOpsRecv .begin (), kTestedOpsRecv .end ());
3093
+ ret.insert (kTestedOpsSend .begin (), kTestedOpsSend .end ());
3033
3094
3034
3095
ret.emplace (HloOpcode::kAfterAll );
3035
3096
ret.emplace (HloOpcode::kAddDependency );
0 commit comments