@@ -1364,3 +1364,64 @@ attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPU
13641364// CHECK: scf.forall.in_parallel
13651365// CHECK: tensor.parallel_insert_slice %[[RES]] into %[[OUT0]][%[[OFFSET0]], 0, %[[OFFSET1]]]
13661366// CHECK: {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
1367+
1368+ // -----
1369+
1370+ // Test for FoldExtractSliceOfFillThroughBlockArgPattern:
1371+ // When an scf.forall's shared_outs init is the result of a linalg.fill and a
1372+ // slice is extracted from the corresponding block argument, the pattern should:
1373+ // 1. Change the forall's init to the fill's destination (the empty tensor).
1374+ // 2. Create a new fill on the extracted slice inside the loop.
1375+
1376+ #config_fill_fold = #iree_codegen.lowering_config <tile_sizes = [[1 , 8 ]]> // Attached to both fills and the arg_compare in the function below; tile sizes [1, 8] equal the scf.forall step there.
1377+
1378+ func.func @fold_fill_through_block_arg (%arg0 : tensor <4 x16 x128 xf16 >) -> (tensor <4 x16 xf16 >, tensor <4 x16 xi32 >) { // Input IR: an scf.forall whose two shared_outs are initialized by linalg.fill results and sliced inside the body.
1379+ %cst = arith.constant 0xFC00 : f16 // 0xFC00 is the f16 bit pattern for -infinity; fill value for the f16 output.
1380+ %c0_i32 = arith.constant 0 : i32 // Fill value for the i32 output.
1381+ %c0 = arith.constant 0 : index // Passed as index_base to arg_compare below.
1382+ %empty_f16 = tensor.empty () : tensor <4 x16 xf16 > // Destination of the f16 fill.
1383+ %empty_i32 = tensor.empty () : tensor <4 x16 xi32 > // Destination of the i32 fill.
1384+ %fill_f16 = linalg.fill {lowering_config = #config_fill_fold } // Full-size f16 fill; its result is the first shared_outs init.
1385+ ins (%cst : f16 ) outs (%empty_f16 : tensor <4 x16 xf16 >) -> tensor <4 x16 xf16 >
1386+ %fill_i32 = linalg.fill {lowering_config = #config_fill_fold } // Full-size i32 fill; its result is the second shared_outs init.
1387+ ins (%c0_i32 : i32 ) outs (%empty_i32 : tensor <4 x16 xi32 >) -> tensor <4 x16 xi32 >
1388+ %result:2 = scf.forall (%iv0 , %iv1 ) = (0 , 0 ) to (4 , 16 ) step (1 , 8 ) // Iterates the 4x16 output in 1x8 tiles.
1389+ shared_outs (%out_f16 = %fill_f16 , %out_i32 = %fill_i32 ) -> (tensor <4 x16 xf16 >, tensor <4 x16 xi32 >) { // Block args %out_f16 / %out_i32 are bound to the fill results.
1390+ %in_slice = tensor.extract_slice %arg0 [%iv0 , %iv1 , 0 ] [1 , 8 , 128 ] [1 , 1 , 1 ] // Input tile: 1x8 along the iterated dims, all 128 elements of dim 2.
1391+ : tensor <4 x16 x128 xf16 > to tensor <1 x8 x128 xf16 >
1392+ %slice_f16 = tensor.extract_slice %out_f16 [%iv0 , %iv1 ] [1 , 8 ] [1 , 1 ] // Slice of a fill-initialized block argument (f16).
1393+ : tensor <4 x16 xf16 > to tensor <1 x8 xf16 >
1394+ %slice_i32 = tensor.extract_slice %out_i32 [%iv0 , %iv1 ] [1 , 8 ] [1 , 1 ] // Slice of a fill-initialized block argument (i32).
1395+ : tensor <4 x16 xi32 > to tensor <1 x8 xi32 >
1396+ %compare:2 = iree_linalg_ext.arg_compare {lowering_config = #config_fill_fold } // Two results: a 1x8 f16 tensor and a 1x8 i32 tensor.
1397+ dimension (2 ) ins (%in_slice : tensor <1 x8 x128 xf16 >) // Operates along dim 2, which is absent from the 1x8 outputs.
1398+ outs (%slice_f16 , %slice_i32 : tensor <1 x8 xf16 >, tensor <1 x8 xi32 >) // The extracted slices are the op's init operands.
1399+ index_base (%c0 : index ) {
1400+ ^bb0 (%lhs: f16 , %rhs: f16 ):
1401+ %cmp = arith.cmpf ogt , %lhs , %rhs : f16 // Comparator region: ordered greater-than on f16.
1402+ iree_linalg_ext.yield %cmp : i1
1403+ } -> tensor <1 x8 xf16 >, tensor <1 x8 xi32 >
1404+ scf.forall.in_parallel {
1405+ tensor.parallel_insert_slice %compare#0 into %out_f16 [%iv0 , %iv1 ] [1 , 8 ] [1 , 1 ] // Writes the f16 tile back at the offsets it was extracted from.
1406+ : tensor <1 x8 xf16 > into tensor <4 x16 xf16 >
1407+ tensor.parallel_insert_slice %compare#1 into %out_i32 [%iv0 , %iv1 ] [1 , 8 ] [1 , 1 ] // Writes the i32 tile back at the same offsets.
1408+ : tensor <1 x8 xi32 > into tensor <4 x16 xi32 >
1409+ }
1410+ } {mapping = [#iree_codegen.workgroup_mapping <y >, #iree_codegen.workgroup_mapping <x >]} // %iv0 is mapped to workgroup y, %iv1 to workgroup x.
1411+ return %result#0 , %result#1 : tensor <4 x16 xf16 >, tensor <4 x16 xi32 >
1412+ }
1413+
1414+ // CHECK-LABEL: func.func @fold_fill_through_block_arg
1415+ // CHECK-DAG: %[[CST_F16:.+]] = arith.constant 0xFC00 : f16
1416+ // CHECK-DAG: %[[CST_I32:.+]] = arith.constant 0 : i32
1417+ // CHECK-DAG: %[[EMPTY_F16:.+]] = tensor.empty() : tensor<4x16xf16>
1418+ // CHECK-DAG: %[[EMPTY_I32:.+]] = tensor.empty() : tensor<4x16xi32>
1419+ // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) = (0, 0) to (4, 16) step (1, 8)
1420+ // CHECK-SAME: shared_outs(%[[OUT_F16:.+]] = %[[EMPTY_F16]], %[[OUT_I32:.+]] = %[[EMPTY_I32]])
1421+ // CHECK: %[[SLICE_F16:.+]] = tensor.extract_slice %[[OUT_F16]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
1422+ // CHECK: %[[FILLED_F16:.+]] = linalg.fill ins(%[[CST_F16]] : f16) outs(%[[SLICE_F16]] : tensor<1x8xf16>)
1423+ // CHECK: %[[SLICE_I32:.+]] = tensor.extract_slice %[[OUT_I32]][%[[IV0]], %[[IV1]]] [1, 8] [1, 1]
1424+ // CHECK: %[[FILLED_I32:.+]] = linalg.fill ins(%[[CST_I32]] : i32) outs(%[[SLICE_I32]] : tensor<1x8xi32>)
1425+ // CHECK: scf.forall
1426+ // CHECK-SAME: shared_outs({{.*}} = %[[FILLED_F16]], {{.*}} = %[[FILLED_I32]])
1427+ // CHECK: iree_linalg_ext.arg_compare
0 commit comments