
[GPU Codegen] Reduction Optimization: Expand iteration space of innermost reduction dimension #22153

@efric

Description

Consider the following linalg.generic, which is representative of a typical matvec:

 %7 = linalg.generic {
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"]}
     ins(%3, %4 : tensor<4x16384xf16>, tensor<6656x16384xf16>)
     outs(%6 : tensor<4x6656xf32>)
     attrs = {lowering_config = #iree_gpu.lowering_config<{
         lane_basis = [[1, 1, 64], [0, 1, 2]],
         partial_reduction = [0, 0, 512],
         subgroup_basis = [[1, 1, 1], [0, 1, 2]],
         thread = [0, 0, 8],
         workgroup = [2, 1, 0]}>} {
 ^bb0(%in: f16, %in_0: f16, %out: f32):
   %8 = arith.extf %in : f16 to f32
   %9 = arith.extf %in_0 : f16 to f32
   %10 = arith.mulf %8, %9 : f32
   %11 = arith.addf %out, %10 : f32
   linalg.yield %11 : f32
 } -> tensor<4x6656xf32>

Currently, post vector-distribute, we see the following:

...
  scf.forall (%arg0, %arg1) = (0, 0) to (4, 6656) step (2, 1) {
    %subview = memref.subview %7[%arg0, %arg1] [2, 1] [1, 1] : memref<4x6656xf32, #amdgpu.address_space<fat_raw_buffer>> to memref<2xf32, strided<[6656], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
    %8:3 = affine.delinearize_index %1 into (1, 64) : index, index, index
    %9:2 = affine.delinearize_index %1 into (64) : index, index
    %10 = arith.muli %8#1, %c512 : index
    %11 = vector.broadcast %10 : index to vector<8xindex>
    %12 = arith.addi %11, %cst_2 : vector<8xindex>
    %13 = arith.muli %9#1, %c8 : index
    %14 = vector.broadcast %13 : index to vector<8xindex>
    %15 = arith.addi %12, %14 : vector<8xindex>
    %16 = vector.shape_cast %15 : vector<8xindex> to vector<1x1x8xindex>
    %17 = arith.cmpi ult, %16, %cst_5 : vector<1x1x8xindex>
    %18 = vector.broadcast %17 : vector<1x1x8xi1> to vector<2x1x1x1x1x8xi1>
    %19 = scf.for %arg2 = %c0 to %c16384 step %c512 iter_args(%arg3 = %cst_4) -> (vector<2x1x1x1x1x8xf32>) {
      %35:4 = affine.delinearize_index %1 into (1, 1, 64) : index, index, index, index
      %36:3 = affine.delinearize_index %1 into (64, 1) : index, index, index
      %37 = affine.linearize_index [%35#2, %c0, %c0, %36#2, %arg0] by (1, 2, 1, 1, 1) : index
      %38 = affine.linearize_index [%35#1, %c0, %c0, %36#1, %arg2] by (1, 1, 1, 64, 8) : index
      %39 = vector.transfer_read %3[%37, %38], %0 {in_bounds = [true, true]} : memref<4x16384xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<1x8xf16>
      %40 = vector.insert_strided_slice %39, %cst_1 {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<2x1x1x1x1x8xf16>
      %41 = affine.linearize_index [%35#2, %c1, %c0, %36#2, %arg0] by (1, 2, 1, 1, 1) : index
      %42 = affine.linearize_index [%35#1, %c0, %c0, %36#1, %arg2] by (1, 1, 1, 64, 8) : index
      %43 = vector.transfer_read %3[%41, %42], %0 {in_bounds = [true, true]} : memref<4x16384xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<1x8xf16>
      %44 = vector.insert_strided_slice %43, %40 {offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<2x1x1x1x1x8xf16>
      %45:3 = affine.delinearize_index %1 into (1, 64) : index, index, index
      %46:2 = affine.delinearize_index %1 into (64) : index, index
      %47 = affine.linearize_index [%45#1, %c0, %c0, %46#1, %arg2] by (1, 1, 1, 64, 8) : index
      %48 = vector.transfer_read %5[%arg1, %47], %0 {in_bounds = [true]} : memref<6656x16384xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<8xf16>
      %49 = vector.insert_strided_slice %48, %cst_0 {offsets = [0, 0, 0], strides = [1]} : vector<8xf16> into vector<1x1x8xf16>
      %50 = arith.extf %44 : vector<2x1x1x1x1x8xf16> to vector<2x1x1x1x1x8xf32>
      %51 = arith.extf %49 : vector<1x1x8xf16> to vector<1x1x8xf32>
      %52 = vector.broadcast %51 : vector<1x1x8xf32> to vector<2x1x1x1x1x8xf32>
      %53 = arith.mulf %50, %52 : vector<2x1x1x1x1x8xf32>
      %54 = arith.select %18, %53, %cst_4 : vector<2x1x1x1x1x8xi1>, vector<2x1x1x1x1x8xf32>
      %55 = arith.addf %arg3, %54 : vector<2x1x1x1x1x8xf32>
      scf.yield %55 : vector<2x1x1x1x1x8xf32>
    }
    %20 = vector.multi_reduction <add>, %19, %cst_3 [1, 3, 5] : vector<2x1x1x1x1x8xf32> to vector<2x1x1xf32>
    %21 = vector.extract %20[0, 0, 0] : f32 from vector<2x1x1xf32>
    %22 = gpu.subgroup_reduce  add %21 cluster(size = 64) : (f32) -> f32
    %23 = vector.insert %22, %cst [0] : f32 into vector<2xf32>
    %24 = vector.extract %20[1, 0, 0] : f32 from vector<2x1x1xf32>
    %25 = gpu.subgroup_reduce  add %24 cluster(size = 64) : (f32) -> f32
...

We want to remove the extraneous vector.multi_reduction that currently sits after the loop. This will help us:

  1. Shrink the loop-carried accumulator (here from vector<2x8xf32> to vector<2x1xf32>, and the savings can be larger depending on the lowering_config), reducing register pressure.
  2. Enable chained FMAs within the loop (see the sketch below).
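
A rough sketch of the loop structure we would like instead (a sketch only, not compiler output; SSA names like %lhs, %rhs, %init are illustrative, and the transfer_reads/masking are elided): the thread-local reduction over the per-thread elements is folded into the loop body, so only a vector<2x1x1xf32> is carried across iterations.

  %19 = scf.for %arg2 = %c0 to %c16384 step %c512 iter_args(%arg3 = %init) -> (vector<2x1x1xf32>) {
    // ... same transfer_reads/broadcasts as above, producing
    //     %lhs, %rhs : vector<2x1x1x1x1x8xf32>
    %prod = arith.mulf %lhs, %rhs : vector<2x1x1x1x1x8xf32>
    // thread-level reduction happens here, inside the loop; the multiply and
    // the accumulating reduction can then lower to chained FMAs
    %partial = vector.multi_reduction <add>, %prod, %arg3 [1, 3, 5]
        : vector<2x1x1x1x1x8xf32> to vector<2x1x1xf32>
    scf.yield %partial : vector<2x1x1xf32>
  }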

Approach: implement iteration-space expansion along the innermost reduction dimension, splitting r into (r0, r1), by introducing a new lowering_config attribute for reductions. For example, consider the following modification to the same linalg.generic, where we change:

partial_reduction = [0, 0, 512] -> partial_reduction = [0, 0, 64, 0], together with expand_dims = [0, 0, 0, 8], i.e. the 512 reduction tile is split into 64 x 8.
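
For concreteness, the relevant lowering_config fields before and after (taken from the configs shown in the examples below; note that the thread tile also moves to the new innermost dimension):

  // current
  partial_reduction = [0, 0, 512]
  thread            = [0, 0, 8]

  // proposed: reduction dim split into (outer 64, inner 8)
  partial_reduction = [0, 0, 64, 0]
  expand_dims       = [0, 0, 0, 8]
  thread            = [0, 0, 0, 8]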

For reference, with the current lowering_config the linalg.generic is lowered to the following at the Linalg level, pre- and post-tiling.

// Pre Padding+Tiling:
// ... Set up 
  %9 = scf.forall (%arg0, %arg1) = (0, 0) to (4, 6656) step (2, 1) shared_outs(%arg2 = %8) 
      -> (tensor<4x6656xf32>) {
    %extracted_slice = tensor.extract_slice %6[%arg0, 0] [2, 16384] [1, 1] 
        : tensor<4x16384xf16> to tensor<2x16384xf16>
    %extracted_slice_0 = tensor.extract_slice %7[%arg1, 0] [1, 16384] [1, 1] 
        : tensor<6656x16384xf16> to tensor<1x16384xf16>
    %extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [2, 1] [1, 1] 
        : tensor<4x6656xf32> to tensor<2x1xf32>
    %10 = linalg.fill ins(%cst : f32) outs(%extracted_slice_1 : tensor<2x1xf32>) -> tensor<2x1xf32>
    %11 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, 
                         affine_map<(d0, d1, d2) -> (d1, d2)>, 
                         affine_map<(d0, d1, d2) -> (d0, d1)>], 
        iterator_types = ["parallel", "parallel", "reduction"]} 
        ins(%extracted_slice, %extracted_slice_0 : tensor<2x16384xf16>, tensor<1x16384xf16>) 
        outs(%10 : tensor<2x1xf32>) 
        attrs =  {lowering_config = #iree_gpu.lowering_config<{
            lane_basis = [[1, 1, 64], [0, 1, 2]], 
            partial_reduction = [0, 0, 512], 
            subgroup_basis = [[1, 1, 1], [0, 1, 2]], 
            thread = [0, 0, 8], 
            workgroup = [2, 1, 0]}>} {
    ^bb0(%in: f16, %in_2: f16, %out: f32):
      %12 = arith.extf %in : f16 to f32
      %13 = arith.extf %in_2 : f16 to f32
      %14 = arith.mulf %12, %13 : f32
      %15 = arith.addf %out, %14 : f32
      linalg.yield %15 : f32
    } -> tensor<2x1xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %11 into %arg2[%arg0, %arg1] [2, 1] [1, 1] 
          : tensor<2x1xf32> into tensor<4x6656xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  iree_codegen.store_to_buffer %9, %5 : tensor<4x6656xf32> 
      into memref<4x6656xf32, #amdgpu.address_space<fat_raw_buffer>>
  return


// Post Tiling
// ... Setup
  %9 = scf.forall (%arg0, %arg1) = (0, 0) to (4, 6656) step (2, 1) shared_outs(%arg2 = %8) 
      -> (tensor<4x6656xf32>) {
    %extracted_slice = tensor.extract_slice %arg2[%arg0, %arg1] [2, 1] [1, 1] 
        : tensor<4x6656xf32> to tensor<2x1xf32>
    %10 = linalg.fill ins(%cst : f32) outs(%extracted_slice : tensor<2x1xf32>) -> tensor<2x1xf32>
    %11 = tensor.empty() : tensor<2x1xf32>
    %12 = linalg.copy ins(%10 : tensor<2x1xf32>) outs(%11 : tensor<2x1xf32>) -> tensor<2x1xf32>
    %13 = tensor.empty() : tensor<2x1x512xf32>
    %14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<2x1x512xf32>) -> tensor<2x1x512xf32>
    %15 = scf.for %arg3 = %c0 to %c16384 step %c512 iter_args(%arg4 = %14) -> (tensor<2x1x512xf32>) {
      %extracted_slice_0 = tensor.extract_slice %6[%arg0, %arg3] [2, 512] [1, 1] 
          : tensor<4x16384xf16> to tensor<2x512xf16> // with expansion, this becomes 2x64x8
      %16 = tensor.empty() : tensor<2x512xf16>
      %17 = linalg.copy ins(%extracted_slice_0 : tensor<2x512xf16>) outs(%16 : tensor<2x512xf16>) 
          -> tensor<2x512xf16>
      %extracted_slice_1 = tensor.extract_slice %7[%arg1, %arg3] [1, 512] [1, 1] 
          : tensor<6656x16384xf16> to tensor<1x512xf16>
      %18 = tensor.empty() : tensor<1x512xf16>
      %19 = linalg.copy ins(%extracted_slice_1 : tensor<1x512xf16>) outs(%18 : tensor<1x512xf16>) 
          -> tensor<1x512xf16>
      %20 = linalg.generic {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, 
                           affine_map<(d0, d1, d2) -> (d1, d2)>, 
                           affine_map<(d0, d1, d2) -> (d0, d1, d2)>], 
          iterator_types = ["parallel", "parallel", "parallel"] } 
          ins(%17, %19 : tensor<2x512xf16>, tensor<1x512xf16>) 
          outs(%arg4 : tensor<2x1x512xf32>) 
          attrs =  {lowering_config = #iree_gpu.lowering_config<{
              lane_basis = [[1, 1, 64], [0, 1, 2]], 
              partial_reduction = [0, 0, 512], 
              subgroup_basis = [[1, 1, 1], [0, 1, 2]], 
              thread = [0, 0, 8], 
              workgroup = [2, 1, 0]}>} {
      ^bb0(%in: f16, %in_2: f16, %out: f32):
        %21 = arith.extf %in : f16 to f32
        %22 = arith.extf %in_2 : f16 to f32
        %23 = arith.mulf %21, %22 : f32
        %24 = linalg.index 2 : index
        %25 = arith.cmpi ult, %24, %c16384 : index
        %26 = arith.select %25, %23, %cst : f32
        %27 = arith.addf %out, %26 : f32
        linalg.yield %27 : f32
      } -> tensor<2x1x512xf32>
      scf.yield %20 : tensor<2x1x512xf32>
    }

    %reduced = linalg.reduce ins(%15 : tensor<2x1x512xf32>) outs(%12 : tensor<2x1xf32>) 
        dimensions = [2] 
      (%in: f32, %init: f32) {
        %16 = arith.addf %in, %init : f32
        linalg.yield %16 : f32
      }
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %reduced into %arg2[%arg0, %arg1] [2, 1] [1, 1] 
          : tensor<2x1xf32> into tensor<4x6656xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  iree_codegen.store_to_buffer %9, %5 : tensor<4x6656xf32> 
      into memref<4x6656xf32, #amdgpu.address_space<fat_raw_buffer>>
  return

Now, with expansion, dim2 (the K dimension) goes from size 16384 to 2048 x 8. The inner factor of 8 is chosen to match the number of elements each thread loads. The number of partial accumulators along the new outer reduction dimension is 64. Note that the reduction loop still has the same trip count (16384 / 512 == 2048 / 64 == 32).
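
Worked out for this configuration (64 lanes per subgroup from lane_basis, 8 elements per lane from thread; a restatement of the numbers above):

  // per-iteration reduction tile
  //   before: partial_reduction = 512            (= 64 lanes * 8 elements/lane)
  //   after : partial_reduction = 64 (r0), expand_dims = 8 (r1), 64 * 8 = 512
  // reduction-loop trip count (unchanged)
  //   before: 16384 / 512 = 32
  //   after :  2048 /  64 = 32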

// Pre padding + tiling
// ... Setup stays the same
  %9 = scf.forall (%arg0, %arg1) = (0, 0) to (4, 6656) step (2, 1) shared_outs(%arg2 = %8) 
      -> (tensor<4x6656xf32>) {
    %extracted_slice = tensor.extract_slice %6[%arg0, 0] [2, 16384] [1, 1] 
        : tensor<4x16384xf16> to tensor<2x16384xf16>
    %extracted_slice_0 = tensor.extract_slice %7[%arg1, 0] [1, 16384] [1, 1] 
        : tensor<6656x16384xf16> to tensor<1x16384xf16>
    %extracted_slice_1 = tensor.extract_slice %arg2[%arg0, %arg1] [2, 1] [1, 1] 
        : tensor<4x6656xf32> to tensor<2x1xf32>
    // Expand by expand_dims
    %expanded_slice = tensor.expand_shape %extracted_slice [[0], [1, 2]] output_shape [2, 2048, 8] 
        : tensor<2x16384xf16> into tensor<2x2048x8xf16>
    %expanded_slice_0 = tensor.expand_shape %extracted_slice_0 [[0], [1, 2]] output_shape [1, 2048, 8] 
        : tensor<1x16384xf16> into tensor<1x2048x8xf16>
    %10 = linalg.fill ins(%cst : f32) outs(%extracted_slice_1 : tensor<2x1xf32>) -> tensor<2x1xf32>
    %11 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, 
                         affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, 
                         affine_map<(d0, d1, d2, d3) -> (d0, d1)>], 
        iterator_types = ["parallel", "parallel", "reduction", "reduction"]} 
        ins(%expanded_slice, %expanded_slice_0 : tensor<2x2048x8xf16>, tensor<1x2048x8xf16>) 
        outs(%10 : tensor<2x1xf32>) 
        attrs =  {lowering_config = #iree_gpu.lowering_config<{
            lane_basis = [[1, 1, 64, 1], [0, 1, 2, 3]], 
            expand_dims = [0, 0, 0, 8], 
            partial_reduction = [0, 0, 64, 0], 
            subgroup_basis = [[1, 1, 1, 1], [0, 1, 2, 3]], 
            thread = [0, 0, 0, 8], 
            workgroup = [2, 1, 0]}>} {
    ^bb0(%in: f16, %in_2: f16, %out: f32):
      %12 = arith.extf %in : f16 to f32
      %13 = arith.extf %in_2 : f16 to f32
      %14 = arith.mulf %12, %13 : f32
      %15 = arith.addf %out, %14 : f32
      linalg.yield %15 : f32
    } -> tensor<2x1xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %11 into %arg2[%arg0, %arg1] [2, 1] [1, 1] 
          : tensor<2x1xf32> into tensor<4x6656xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  iree_codegen.store_to_buffer %9, %5 : tensor<4x6656xf32> 
      into memref<4x6656xf32, #amdgpu.address_space<fat_raw_buffer>>
  return


// Post tiling (still using PartialReductionOuterReduction reduction strategy)
// ... Setup stays the same
  %9 = scf.forall (%arg0, %arg1) = (0, 0) to (4, 6656) step (2, 1) shared_outs(%arg2 = %8) 
      -> (tensor<4x6656xf32>) { 
    %extracted_slice = tensor.extract_slice %arg2[%arg0, %arg1] [2, 1] [1, 1] 
        : tensor<4x6656xf32> to tensor<2x1xf32>
    %10 = linalg.fill ins(%cst : f32) outs(%extracted_slice : tensor<2x1xf32>) -> tensor<2x1xf32>
    %11 = tensor.empty() : tensor<2x1xf32>
    %12 = linalg.copy ins(%10 : tensor<2x1xf32>) outs(%11 : tensor<2x1xf32>) 
        -> tensor<2x1xf32> //output
    %13 = tensor.empty() : tensor<2x1x64xf32>
    %14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<2x1x64xf32>) 
        -> tensor<2x1x64xf32> // new loop-carried accumulator

    %15 = scf.for %arg3 = %c0 to %c2048 step %c64 iter_args(%arg4 = %14) -> (tensor<2x1x64xf32>) { 
      // the thread-level reduction now happens inside the loop
      %extracted_slice_0 = tensor.extract_slice %6[%arg0, %arg3, 0] [2, 64, 8] [1, 1, 1] 
          : tensor<4x2048x8xf16> to tensor<2x64x8xf16>
      %16 = tensor.empty() : tensor<2x64x8xf16>
      %17 = linalg.copy ins(%extracted_slice_0 : tensor<2x64x8xf16>) outs(%16 : tensor<2x64x8xf16>) 
          -> tensor<2x64x8xf16>

      %extracted_slice_1 = tensor.extract_slice %7[%arg1, %arg3, 0] [1, 64, 8] [1, 1, 1] 
          : tensor<6656x2048x8xf16> to tensor<1x64x8xf16>
      %18 = tensor.empty() : tensor<1x64x8xf16>
      %19 = linalg.copy ins(%extracted_slice_1 : tensor<1x64x8xf16>) outs(%18 : tensor<1x64x8xf16>) 
          -> tensor<1x64x8xf16>
      
      %20 = linalg.generic {
          indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, 
                           affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, 
                           affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], 
          iterator_types = ["parallel", "parallel", "parallel", "reduction"] } 
          ins(%17, %19 : tensor<2x64x8xf16>, tensor<1x64x8xf16>) 
          outs(%arg4 : tensor<2x1x64xf32>) 
          attrs =  {lowering_config = #iree_gpu.lowering_config<{ 
              lane_basis = [[1, 1, 64, 1], [0, 1, 2, 3]], 
              expand_dims = [0, 0, 0, 8], 
              partial_reduction = [0, 0, 64, 0], 
              subgroup_basis = [[1, 1, 1, 1], [0, 1, 2, 3]], 
              thread = [0, 0, 0, 8], 
              workgroup = [2, 1, 0]}>} {
      ^bb0(%in: f16, %in_2: f16, %out: f32):
        %21 = arith.extf %in : f16 to f32
        %22 = arith.extf %in_2 : f16 to f32
        %23 = arith.mulf %21, %22 : f32
        %24 = linalg.index 2 : index
        %25 = arith.cmpi ult, %24, %c16384 : index
        %26 = arith.select %25, %23, %cst : f32
        %27 = arith.addf %out, %26 : f32
        linalg.yield %27 : f32
      } -> tensor<2x1x64xf32>
      scf.yield %20 : tensor<2x1x64xf32>
    }

    %reduced = linalg.reduce ins(%15 : tensor<2x1x64xf32>) outs(%12 : tensor<2x1xf32>) 
        dimensions = [2] // this can lower to just gpu.subgroup_reduce ops; the thread-level reduce now happens inside the loop
      (%in: f32, %init: f32) {
        %16 = arith.addf %in, %init : f32
        linalg.yield %16 : f32
    }

    scf.forall.in_parallel {
      tensor.parallel_insert_slice %reduced into %arg2[%arg0, %arg1] [2, 1] [1, 1] 
          : tensor<2x1xf32> into tensor<4x6656xf32>
    }

  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  iree_codegen.store_to_buffer %9, %5 : tensor<4x6656xf32> 
      into memref<4x6656xf32, #amdgpu.address_space<fat_raw_buffer>>
  return
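
With the thread-level (inner, size-8) reduction already performed inside the loop, the trailing linalg.reduce should lower, after vector distribution, to just the cross-lane reduction over the 64 partial accumulators, roughly as follows (a sketch mirroring the IR at the top of the issue; %acc stands for the distributed loop result and is illustrative):

  %r0 = vector.extract %acc[0, 0, 0] : f32 from vector<2x1x1xf32>
  %s0 = gpu.subgroup_reduce add %r0 cluster(size = 64) : (f32) -> f32
  %r1 = vector.extract %acc[1, 0, 0] : f32 from vector<2x1x1xf32>
  %s1 = gpu.subgroup_reduce add %r1 cluster(size = 64) : (f32) -> f32
  // no trailing vector.multi_reduction: the r1 (size 8) reduction already
  // happened inside the scf.for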
