@@ -1364,3 +1364,57 @@ attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPU
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %[[RES]] into %[[OUT0]][%[[OFFSET0]], 0, %[[OFFSET1]]]
// CHECK: {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
1367+
// -----

// Tests that the linalg.fill producers of the scf.forall shared_outs are
// folded into the loop body: after the transform the forall should take the
// tensor.empty values directly, and each iteration fills only the extracted
// slice before it is consumed by iree_linalg_ext.arg_compare.
#config_fill_fold = #iree_codegen.lowering_config<tile_sizes = [[1, 8]]>
func.func @fold_fill_through_block_arg(%arg0: tensor<4x16x128xf16>) -> (tensor<4x16xf16>, tensor<4x16xi32>) {
  // 0xFC00 is f16 -inf, the identity for an ogt max-style compare.
  %cst = arith.constant 0xFC00 : f16
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %empty_f16 = tensor.empty() : tensor<4x16xf16>
  %empty_i32 = tensor.empty() : tensor<4x16xi32>
  %fill_f16 = linalg.fill {lowering_config = #config_fill_fold}
      ins(%cst : f16) outs(%empty_f16 : tensor<4x16xf16>) -> tensor<4x16xf16>
  %fill_i32 = linalg.fill {lowering_config = #config_fill_fold}
      ins(%c0_i32 : i32) outs(%empty_i32 : tensor<4x16xi32>) -> tensor<4x16xi32>
  %result:2 = scf.forall (%iv0, %iv1) = (0, 0) to (4, 16) step (1, 8)
      shared_outs(%out_f16 = %fill_f16, %out_i32 = %fill_i32) -> (tensor<4x16xf16>, tensor<4x16xi32>) {
    %in_slice = tensor.extract_slice %arg0[%iv0, %iv1, 0] [1, 8, 128] [1, 1, 1]
        : tensor<4x16x128xf16> to tensor<1x8x128xf16>
    %slice_f16 = tensor.extract_slice %out_f16[%iv0, %iv1] [1, 8] [1, 1]
        : tensor<4x16xf16> to tensor<1x8xf16>
    %slice_i32 = tensor.extract_slice %out_i32[%iv0, %iv1] [1, 8] [1, 1]
        : tensor<4x16xi32> to tensor<1x8xi32>
    %compare:2 = iree_linalg_ext.arg_compare {lowering_config = #config_fill_fold}
        dimension(2) ins(%in_slice : tensor<1x8x128xf16>)
        outs(%slice_f16, %slice_i32 : tensor<1x8xf16>, tensor<1x8xi32>)
        index_base(%c0 : index) {
    ^bb0(%lhs: f16, %rhs: f16):
      %cmp = arith.cmpf ogt, %lhs, %rhs : f16
      iree_linalg_ext.yield %cmp : i1
    } -> tensor<1x8xf16>, tensor<1x8xi32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %compare#0 into %out_f16[%iv0, %iv1] [1, 8] [1, 1]
          : tensor<1x8xf16> into tensor<4x16xf16>
      tensor.parallel_insert_slice %compare#1 into %out_i32[%iv0, %iv1] [1, 8] [1, 1]
          : tensor<1x8xi32> into tensor<4x16xi32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  return %result#0, %result#1 : tensor<4x16xf16>, tensor<4x16xi32>
}

// CHECK-LABEL: func.func @fold_fill_through_block_arg
// CHECK-DAG:     %[[CST_F16:.+]] = arith.constant 0xFC00 : f16
// CHECK-DAG:     %[[CST_I32:.+]] = arith.constant 0 : i32
// CHECK-DAG:     %[[EMPTY_F16:.+]] = tensor.empty() : tensor<4x16xf16>
// CHECK-DAG:     %[[EMPTY_I32:.+]] = tensor.empty() : tensor<4x16xi32>
// CHECK:         scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (4, 16) step (1, 8)
// CHECK-SAME:      shared_outs(%[[OUT_F16:.+]] = %[[EMPTY_F16]], %[[OUT_I32:.+]] = %[[EMPTY_I32]])
// CHECK:           %[[SLICE_F16:.+]] = tensor.extract_slice %[[OUT_F16]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
// CHECK:           %[[FILLED_F16:.+]] = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = {{\[\[}}1, 8]]>} ins(%[[CST_F16]] : f16) outs(%[[SLICE_F16]] : tensor<1x8xf16>)
// CHECK:           %[[SLICE_I32:.+]] = tensor.extract_slice %[[OUT_I32]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
// CHECK:           %[[FILLED_I32:.+]] = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = {{\[\[}}1, 8]]>} ins(%[[CST_I32]] : i32) outs(%[[SLICE_I32]] : tensor<1x8xi32>)
// CHECK:           scf.forall
// CHECK-SAME:        shared_outs({{.*}} = %[[FILLED_F16]], {{.*}} = %[[FILLED_I32]])
// CHECK:             iree_linalg_ext.arg_compare