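Today I'm sharing a bit of how MLIR processes things at the source level. If you can't follow how memref is lowered to LLVM, you won't understand the memory-access optimizations in the layers above it~
- Create memref.ExtractStridedMetadataOp to expand the memref descriptor: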
%92:8 = "memref.extract_strided_metadata"(%35) :
(memref<300x100x250xf32>) ->
(memref<f32>, index, index, index, index, index, index, index)
** Insert : 'llvm.extractvalue'(0x5aaf3b0)(%34) <{position = array<i64: 0>}>
** Insert : 'llvm.extractvalue'(0x5aaf440)
** Insert : 'llvm.mlir.undef'(0x5aabff0)%94 = !llvm.struct<(ptr, ptr, i64)>
** Insert : 'llvm.insertvalue'(0x5aaf4d0)(%94, %92) <{position = array<i64: 0>}>
** Insert : 'llvm.insertvalue'(0x5aaf580)
** Insert : 'llvm.mlir.constant'(0x5aaa0e0)
** Insert : 'llvm.insertvalue'(0x5aaf630)
** Insert : 'llvm.extractvalue'(0x5aaf6e0)<{position = array<i64: 2>}> : (!llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>)
** Insert : 'llvm.extractvalue'(0x5aaf770)
** Insert : 'llvm.extractvalue'(0x5aaf800)
** Insert : 'llvm.extractvalue'(0x5aaf890)
** Insert : 'llvm.extractvalue'(0x5aaf920)
** Insert : 'llvm.extractvalue'(0x5aaf9b0)
** Insert : 'llvm.extractvalue'(0x5aafa40)
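To make the log above easier to read, here is a rough Python model of what this step unpacks. It is only a sketch: `MemRefDescriptor`, `BaseBuffer`, `row_major_strides` and `extract_strided_metadata` are hypothetical names made up for illustration, not MLIR APIs, and the pointer value is a placeholder. The field order follows the `!llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>` type shown in the log, and the base buffer follows the `!llvm.struct<(ptr, ptr, i64)>` type with its offset set by a constant, assumed here to be 0.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class BaseBuffer:
    """Stand-in for the rank-0 result, !llvm.struct<(ptr, ptr, i64)>."""
    allocated_ptr: int  # field 0
    aligned_ptr: int    # field 1
    offset: int         # field 2, assumed to be the inserted constant 0


@dataclass
class MemRefDescriptor:
    """Stand-in for !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>."""
    allocated_ptr: int  # field 0
    aligned_ptr: int    # field 1
    offset: int         # field 2
    sizes: List[int]    # field 3
    strides: List[int]  # field 4


def row_major_strides(shape: List[int]) -> List[int]:
    """Strides (in elements) of a contiguous row-major memref."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def extract_strided_metadata(desc: MemRefDescriptor) -> tuple:
    """Return (base buffer, offset, *sizes, *strides): 2 + 2 * rank values."""
    base = BaseBuffer(desc.allocated_ptr, desc.aligned_ptr, 0)
    return (base, desc.offset, *desc.sizes, *desc.strides)


shape = [300, 100, 250]  # memref<300x100x250xf32>
desc = MemRefDescriptor(0x1000, 0x1000, 0, shape, row_major_strides(shape))
results = extract_strided_metadata(desc)
```

For the rank-3 memref this gives 8 values, which lines up with the `%92:8` result count and with the run of `llvm.extractvalue` ops that read the offset, the three sizes and the three strides.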
- Compute the subview's strides and offset from the source strides and offset:
finaloffset = s0 + s1 * s2 + s3 * s4 + s5 * s6
= 0 + wkg_z * 25000 + (wkg_y * 50) * 250 + 0 * 1
= wkg_z * 25000 + wkg_y * 12500
value0: 0
value1: %workgroup_id_z = hal.interface.workgroup.id[2] : index
value2: 25000 : index
value3: %2 = affine.apply affine_map<()[s0] -> (s0 * 50)>()[%workgroup_id_y]
value4: 250 : index
value5: 0
value6: 1
%93 = "affine.apply"(%81, %77) <{
map = affine_map<()[s0, s1] -> (s0 * 25000 + s1 * 12500)>}> : (index, index) -> index
size0: 1 : i64
size1: 50 : i64
size2: 250 : i64
stride0: 25000 : index
stride1: 250 : index
stride2: 1 : index
** Insert : 'arith.constant'(0x5aac050)<{value = 25000 : index}> : () -> index
"llvm.mlir.constant"() <{value = 25000 : index}> : () -> i64
** Insert : 'arith.muli'(0x5ab2770)
%89 = llvm.mul %69, %88 : i64
** Insert : 'arith.constant'(0x5aac0c0)
** Insert : 'arith.muli'(0x5ab2820)
** Insert : 'arith.addi'(0x5ab28d0)
** Replace : 'affine.apply'(0x5aaf300)
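The offset arithmetic in this step can be checked with a few lines of Python. This is a sketch only: `subview_offset` and `subview_strides` are hypothetical helper names, the workgroup ids are plain integers chosen for the example, and the subview steps are assumed to be 1 in every dimension (which matches the resulting strides 25000, 250, 1).

```python
from typing import List


def subview_offset(src_offset: int, sub_offsets: List[int],
                   src_strides: List[int]) -> int:
    """finaloffset = src_offset + sum(sub_offset[i] * src_stride[i])."""
    total = src_offset
    for off, stride in zip(sub_offsets, src_strides):
        total += off * stride
    return total


def subview_strides(src_strides: List[int], steps: List[int]) -> List[int]:
    """Each subview stride is the source stride scaled by the subview step."""
    return [s * step for s, step in zip(src_strides, steps)]


src_strides = [25000, 250, 1]  # strides of memref<300x100x250xf32>
wkg_z, wkg_y = 2, 1            # example workgroup ids, chosen arbitrarily

# Offsets per dimension: [wkg_z, wkg_y * 50, 0], as in value1, value3, value5.
offset = subview_offset(0, [wkg_z, wkg_y * 50, 0], src_strides)
folded = wkg_z * 25000 + wkg_y * 12500  # the folded affine_map form
strides = subview_strides(src_strides, [1, 1, 1])
```

The unfolded sum and the folded `s0 * 25000 + s1 * 12500` map agree for any pair of ids, which is why the `affine.apply` ends up as two `muli` ops and one `addi` after lowering.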
- Create memref.reinterpret_cast to reinterpret result #0 of ExtractStridedMetadataOp (the memory pointer):
%118 = "memref.reinterpret_cast"(%106#0, %117)
<{ operand_segment_sizes = array<i32: 1, 1, 0, 0>,
static_offsets = array<i64: -9223372036854775808>,
static_sizes = array<i64: 1, 50, 250>,
static_strides = array<i64: 25000, 250, 1>
}> : (memref<f32>, index)
-> memref<1x50x250xf32, strided<[25000, 250, 1], offset: ?>>
** Insert : 'llvm.mlir.undef'(0x5ab3310)
!llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
** Insert : 'llvm.extractvalue'(0x5ab2b40)
** Insert : 'llvm.extractvalue'(0x5ab2bd0)
** Insert : 'llvm.insertvalue'(0x5ab2c60)
** Insert : 'llvm.insertvalue'(0x5ab2d10)
** Insert : 'llvm.insertvalue'(0x5ab2dc0)
** Insert : 'llvm.mlir.constant'(0x5ab2e70)
** Insert : 'llvm.insertvalue'(0x5ab2ed0) 3,0
** Insert : 'llvm.mlir.constant'(0x5ab2f80)
** Insert : 'llvm.insertvalue'(0x5ab2fe0) 4,0
** Insert : 'llvm.mlir.constant'(0x5ab3090)
** Insert : 'llvm.insertvalue'(0x5ab30f0) 3,1
** Insert : 'llvm.mlir.constant'(0x5ab0230)
** Insert : 'llvm.insertvalue'(0x5ab0290) 4,1
** Insert : 'llvm.mlir.constant'(0x5ab0340)
** Insert : 'llvm.insertvalue'(0x5ab03a0) 3,2
** Insert : 'llvm.mlir.constant'(0x5ab0450)
** Insert : 'llvm.insertvalue'(0x5ab04b0) 4,2
** Replace : 'memref.reinterpret_cast'(0x5aa9d30)
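The sequence of `llvm.insertvalue` ops above amounts to filling a fresh descriptor field by field. The Python below is a minimal sketch of that idea, not the actual lowering code: `reinterpret_cast` and `linear_index` are hypothetical names, the descriptor is modelled as a nested list, and the pointer and offset values are placeholders.

```python
from typing import List, Sequence


def reinterpret_cast(allocated_ptr: int, aligned_ptr: int, offset: int,
                     sizes: Sequence[int], strides: Sequence[int]) -> List:
    """Build a descriptor shaped like (ptr, ptr, i64, array<N x i64>, array<N x i64>)."""
    rank = len(sizes)
    desc = [None, None, None, [None] * rank, [None] * rank]  # like llvm.mlir.undef
    desc[0] = allocated_ptr  # pointers taken from the base buffer
    desc[1] = aligned_ptr
    desc[2] = offset         # the dynamic offset operand
    for i in range(rank):
        desc[3][i] = sizes[i]    # position [3, i]
        desc[4][i] = strides[i]  # position [4, i]
    return desc


def linear_index(desc: List, indices: Sequence[int]) -> int:
    """Element index relative to the aligned pointer: offset + sum(idx * stride)."""
    return desc[2] + sum(i * s for i, s in zip(indices, desc[4]))


# Offset 62500 corresponds to wkg_z = 2, wkg_y = 1 in wkg_z * 25000 + wkg_y * 12500.
view = reinterpret_cast(0x1000, 0x1000, 62500, [1, 50, 250], [25000, 250, 1])
idx = linear_index(view, (0, 1, 2))
```

The loop writes size and stride alternately per dimension, which is the same `3,0` / `4,0` / `3,1` / `4,1` / `3,2` / `4,2` order annotated in the log. Loads and stores on the new view then address memory through `linear_index`-style arithmetic on the aligned pointer.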