首先编译器版本低导致的性能问题。
其次,大部分同学没有汇编优化的技巧,或者说没有看过反汇编代码,导致变量和向量寄存器被随意使用(类似全部命名为 0-110,反正通过下标取就好了),这将导致性能恶化。不断的 push/pop 栈开销:简单的 push/pop 本身并不会导致很严重的性能恶化,真正的问题是由此引发的严重的指令流水线停顿与等待。
好处:指令集使用和指令重排更可控,同时解决了低版本编译器的性能不足问题!
缺点:没人想写!
x86实例一:
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst3;
__m256 dst6;
__m256 dst9;
__m256 dst1;
__m256 dst4;
__m256 dst7;
__m256 dst10;
__m256 dst2;
__m256 dst5;
__m256 dst8;
__m256 dst11;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0);
dst9 = _mm256_load_ps(dst + 3 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8);
dst10 = _mm256_load_ps(dst + 3 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16);
dst11 = _mm256_load_ps(dst + 3 * dst_stride + 16);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
dst10 = _mm256_setzero_ps();
dst11 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 8);
dst6 = _mm256_load_ps(bias + 16);
dst9 = _mm256_load_ps(bias + 24);
dst1 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 8);
dst7 = _mm256_load_ps(bias + 16);
dst10 = _mm256_load_ps(bias + 24);
dst2 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 8);
dst8 = _mm256_load_ps(bias + 16);
dst11 = _mm256_load_ps(bias + 24);
}
for (int i = 0; i < (deep >> 3); ++i) {
// 3x32block:寄存器的使用设计
// 必须先对src进行读取才不会出现寄存器的溢出问题,而不是先对weight进行读取,否则会性能恶化。
// 但是会导致严重的流水线依赖问题(weight的加载),尝试使用4x24block?
// 再看看32正好是两个cache line,而24是1.5个cache line,多线程存在伪共享问题。
// 4x24block:需要先加载weight,然后在加载src就不会出现寄存器溢出问题。
// 同时4x24block需要使用4个cache line来存src(再有其他模块下性能会恶化更严重),而3x32只需要3
// 个cache line,综合考虑3x32
// 实际测试却是4x24效果更好!! 可能流水更好??? 在研究吧!!
// bock0
__m256 src00 = _mm256_set1_ps(*(src + 0));
__m256 src10 = _mm256_set1_ps(*(src + 8));
__m256 src20 = _mm256_set1_ps(*(src + 16));
__m256 weight00 = _mm256_load_ps(weight + 0);
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 weight10 = _mm256_load_ps(weight + 8);
dst3 = _mm256_fmadd_ps(dst3, src00, weight10);
dst4 = _mm256_fmadd_ps(dst4, src10, weight10);
dst5 = _mm256_fmadd_ps(dst5, src20, weight10);
__m256 weight20 = _mm256_load_ps(weight + 16);
dst6 = _mm256_fmadd_ps(dst6, src00, weight20);
dst7 = _mm256_fmadd_ps(dst7, src10, weight20);
dst8 = _mm256_fmadd_ps(dst8, src20, weight20);
__m256 weight30 = _mm256_load_ps(weight + 24);
dst9 = _mm256_fmadd_ps(dst9, src00, weight30);
dst10 = _mm256_fmadd_ps(dst10, src10, weight30);
dst11 = _mm256_fmadd_ps(dst11, src20, weight30);
// bock1
__m256 src01 = _mm256_set1_ps(*(src + 1));
__m256 src11 = _mm256_set1_ps(*(src + 9));
__m256 src21 = _mm256_set1_ps(*(src + 17));
__m256 weight01 = _mm256_load_ps(weight + 32);
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 weight11 = _mm256_load_ps(weight + 40);
dst3 = _mm256_fmadd_ps(dst3, src01, weight11);
dst4 = _mm256_fmadd_ps(dst4, src11, weight11);
dst5 = _mm256_fmadd_ps(dst5, src21, weight11);
__m256 weight21 = _mm256_load_ps(weight + 48);
dst6 = _mm256_fmadd_ps(dst6, src01, weight21);
dst7 = _mm256_fmadd_ps(dst7, src11, weight21);
dst8 = _mm256_fmadd_ps(dst8, src21, weight21);
__m256 weight31 = _mm256_load_ps(weight + 56);
dst9 = _mm256_fmadd_ps(dst9, src01, weight31);
dst10 = _mm256_fmadd_ps(dst10, src11, weight31);
dst11 = _mm256_fmadd_ps(dst11, src21, weight31);
// bock2
__m256 src02 = _mm256_set1_ps(*(src + 2));
__m256 src12 = _mm256_set1_ps(*(src + 10));
__m256 src22 = _mm256_set1_ps(*(src + 18));
__m256 weight02 = _mm256_load_ps(weight + 64);
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 weight12 = _mm256_load_ps(weight + 72);
dst3 = _mm256_fmadd_ps(dst3, src02, weight12);
dst4 = _mm256_fmadd_ps(dst4, src12, weight12);
dst5 = _mm256_fmadd_ps(dst5, src22, weight12);
__m256 weight22 = _mm256_load_ps(weight + 80);
dst6 = _mm256_fmadd_ps(dst6, src02, weight22);
dst7 = _mm256_fmadd_ps(dst7, src12, weight22);
dst8 = _mm256_fmadd_ps(dst8, src22, weight22);
__m256 weight32 = _mm256_load_ps(weight + 88);
dst9 = _mm256_fmadd_ps(dst9, src02, weight32);
dst10 = _mm256_fmadd_ps(dst10, src12, weight32);
dst11 = _mm256_fmadd_ps(dst11, src22, weight32);
// bock3
__m256 src03 = _mm256_set1_ps(*(src + 3));
__m256 src13 = _mm256_set1_ps(*(src + 11));
__m256 src23 = _mm256_set1_ps(*(src + 19));
__m256 weight03 = _mm256_load_ps(weight + 96);
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 weight13 = _mm256_load_ps(weight + 104);
dst3 = _mm256_fmadd_ps(dst3, src03, weight13);
dst4 = _mm256_fmadd_ps(dst4, src13, weight13);
dst5 = _mm256_fmadd_ps(dst5, src23, weight13);
__m256 weight23 = _mm256_load_ps(weight + 112);
dst6 = _mm256_fmadd_ps(dst6, src03, weight23);
dst7 = _mm256_fmadd_ps(dst7, src13, weight23);
dst8 = _mm256_fmadd_ps(dst8, src23, weight23);
__m256 weight33 = _mm256_load_ps(weight + 120);
dst9 = _mm256_fmadd_ps(dst9, src03, weight33);
dst10 = _mm256_fmadd_ps(dst10, src13, weight33);
dst11 = _mm256_fmadd_ps(dst11, src23, weight33);
// bock4
__m256 src04 = _mm256_set1_ps(*(src + 4));
__m256 src14 = _mm256_set1_ps(*(src + 12));
__m256 src24 = _mm256_set1_ps(*(src + 20));
__m256 weight04 = _mm256_load_ps(weight + 128);
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 weight14 = _mm256_load_ps(weight + 136);
dst3 = _mm256_fmadd_ps(dst3, src04, weight14);
dst4 = _mm256_fmadd_ps(dst4, src14, weight14);
dst5 = _mm256_fmadd_ps(dst5, src24, weight14);
__m256 weight24 = _mm256_load_ps(weight + 144);
dst6 = _mm256_fmadd_ps(dst6, src04, weight24);
dst7 = _mm256_fmadd_ps(dst7, src14, weight24);
dst8 = _mm256_fmadd_ps(dst8, src24, weight24);
__m256 weight34 = _mm256_load_ps(weight + 152);
dst9 = _mm256_fmadd_ps(dst9, src04, weight34);
dst10 = _mm256_fmadd_ps(dst10, src14, weight34);
dst11 = _mm256_fmadd_ps(dst11, src24, weight34);
// bock5
__m256 src05 = _mm256_set1_ps(*(src + 5));
__m256 src15 = _mm256_set1_ps(*(src + 13));
__m256 src25 = _mm256_set1_ps(*(src + 21));
__m256 weight05 = _mm256_load_ps(weight + 160);
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 weight15 = _mm256_load_ps(weight + 168);
dst3 = _mm256_fmadd_ps(dst3, src05, weight15);
dst4 = _mm256_fmadd_ps(dst4, src15, weight15);
dst5 = _mm256_fmadd_ps(dst5, src25, weight15);
__m256 weight25 = _mm256_load_ps(weight + 176);
dst6 = _mm256_fmadd_ps(dst6, src05, weight25);
dst7 = _mm256_fmadd_ps(dst7, src15, weight25);
dst8 = _mm256_fmadd_ps(dst8, src25, weight25);
__m256 weight35 = _mm256_load_ps(weight + 184);
dst9 = _mm256_fmadd_ps(dst9, src05, weight35);
dst10 = _mm256_fmadd_ps(dst10, src15, weight35);
dst11 = _mm256_fmadd_ps(dst11, src25, weight35);
// bock6
__m256 src06 = _mm256_set1_ps(*(src + 6));
__m256 src16 = _mm256_set1_ps(*(src + 14));
__m256 src26 = _mm256_set1_ps(*(src + 22));
__m256 weight06 = _mm256_load_ps(weight + 192);
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 weight16 = _mm256_load_ps(weight + 200);
dst3 = _mm256_fmadd_ps(dst3, src06, weight16);
dst4 = _mm256_fmadd_ps(dst4, src16, weight16);
dst5 = _mm256_fmadd_ps(dst5, src26, weight16);
__m256 weight26 = _mm256_load_ps(weight + 208);
dst6 = _mm256_fmadd_ps(dst6, src06, weight26);
dst7 = _mm256_fmadd_ps(dst7, src16, weight26);
dst8 = _mm256_fmadd_ps(dst8, src26, weight26);
__m256 weight36 = _mm256_load_ps(weight + 216);
dst9 = _mm256_fmadd_ps(dst9, src06, weight36);
dst10 = _mm256_fmadd_ps(dst10, src16, weight36);
dst11 = _mm256_fmadd_ps(dst11, src26, weight36);
// bock7
__m256 src07 = _mm256_set1_ps(*(src + 7));
__m256 src17 = _mm256_set1_ps(*(src + 15));
__m256 src27 = _mm256_set1_ps(*(src + 23));
__m256 weight07 = _mm256_load_ps(weight + 224);
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 weight17 = _mm256_load_ps(weight + 232);
dst3 = _mm256_fmadd_ps(dst3, src07, weight17);
dst4 = _mm256_fmadd_ps(dst4, src17, weight17);
dst5 = _mm256_fmadd_ps(dst5, src27, weight17);
__m256 weight27 = _mm256_load_ps(weight + 240);
dst6 = _mm256_fmadd_ps(dst6, src07, weight27);
dst7 = _mm256_fmadd_ps(dst7, src17, weight27);
dst8 = _mm256_fmadd_ps(dst8, src27, weight27);
__m256 weight37 = _mm256_load_ps(weight + 248);
dst9 = _mm256_fmadd_ps(dst9, src07, weight37);
dst10 = _mm256_fmadd_ps(dst10, src17, weight37);
dst11 = _mm256_fmadd_ps(dst11, src27, weight37);
src = src + src_stride;
weight += 1024;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst10 = _mm256_min_ps(dst10, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst11 = _mm256_min_ps(dst11, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 1 * src_stride + 0, dst3);
_mm256_store_ps(dst + 1 * src_stride + 8, dst4);
_mm256_store_ps(dst + 1 * src_stride + 16, dst5);
_mm256_store_ps(dst + 2 * src_stride + 0, dst6);
_mm256_store_ps(dst + 2 * src_stride + 8, dst7);
_mm256_store_ps(dst + 2 * src_stride + 16, dst8);
_mm256_store_ps(dst + 3 * src_stride + 0, dst9);
_mm256_store_ps(dst + 3 * src_stride + 8, dst10);
_mm256_store_ps(dst + 3 * src_stride + 16, dst11);
}
x86实例二:
/**
 * TiledC8MatmulFp32 — FP32 matmul over channel-8 tiles.
 *
 * For each of oc8 output-channel blocks, accumulates an 8x8 output tile:
 * dst row r (8 floats at dst + r*C8NUM) += sum over the reduction of
 * src[r*C8NUM + j] * weight_vector_j, where each inner iteration consumes
 * C64NUM (= 8 rows x 8 depth) floats of both src and weight.
 *
 * The release build uses hand-written inline asm with the accumulators pinned
 * to ymm0..ymm7; the ENABLE_DEBUG build uses the equivalent intrinsic loop.
 * NOTE(review): GCC-specific (register-asm variables, extended asm) — not
 * portable to MSVC.
 */
void TiledC8MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic8, size_t oc8) {
  const float *src_tmp = src;
  // NOTE(review): `int i` compared against `size_t oc8` is a signed/unsigned
  // mismatch; harmless for realistic oc8 but `size_t i` would be cleaner.
  for (int i = 0; i < oc8; ++i) {
    src = src_tmp;
    // Pin the 8 accumulators to fixed registers (ymm0..ymm7) so the asm below
    // can update them in place across loop iterations without spills.
    // NOTE(review): local register-asm variables are only guaranteed to occupy
    // the named register when passed as asm operands; relying on the binding
    // across a plain asm() with no outputs is fragile — TODO confirm this holds
    // at all optimization levels of every supported compiler.
    register __m256 dst1 asm("ymm0") = _mm256_setzero_ps();
    register __m256 dst2 asm("ymm1") = _mm256_setzero_ps();
    register __m256 dst3 asm("ymm2") = _mm256_setzero_ps();
    register __m256 dst4 asm("ymm3") = _mm256_setzero_ps();
    register __m256 dst5 asm("ymm4") = _mm256_setzero_ps();
    register __m256 dst6 asm("ymm5") = _mm256_setzero_ps();
    register __m256 dst7 asm("ymm6") = _mm256_setzero_ps();
    register __m256 dst8 asm("ymm7") = _mm256_setzero_ps();
    for (size_t ic8_tmp = 0; ic8_tmp < ic8; ++ic8_tmp) {
#ifndef ENABLE_DEBUG
      // Eight depth steps per chunk. Pattern of each step: load one 8-float
      // weight vector (32(%1)-strided), broadcast the 8 per-row src scalars
      // (rows are 32 bytes apart, depth advances 4 bytes), then eight
      // vfmadd231ps (acc += src * weight) into ymm0..ymm7. Scratch registers
      // ymm8..ymm15 are rotated to break dependency chains.
      // NOTE(review): the asm writes ymm8..ymm15 yet declares only "memory" as
      // a clobber — ymm8..ymm15 (and xmm equivalents) belong in the clobber
      // list; also ymm0..ymm7 are modified without being asm outputs (see the
      // register-binding caveat above).
      asm volatile(
        // 1
        "vmovups (%1), %%ymm8\n"
        "vbroadcastss (%0), %%ymm9\n"
        "vbroadcastss 32(%0), %%ymm10\n"
        "vbroadcastss 64(%0), %%ymm11\n"
        "vbroadcastss 96(%0), %%ymm12\n"
        "vbroadcastss 128(%0), %%ymm13\n"
        "vbroadcastss 160(%0), %%ymm14\n"
        "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n"
        "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n"
        "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n"
        "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n"
        "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n"
        "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n"
        "vbroadcastss 192(%0), %%ymm9\n"
        "vbroadcastss 224(%0), %%ymm10\n"
        "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n"
        "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n"
        // 2
        "vmovups 32(%1), %%ymm15\n"
        "vbroadcastss 4(%0), %%ymm11\n"
        "vbroadcastss 36(%0), %%ymm12\n"
        "vbroadcastss 68(%0), %%ymm13\n"
        "vbroadcastss 100(%0), %%ymm14\n"
        "vbroadcastss 132(%0), %%ymm9\n"
        "vbroadcastss 164(%0), %%ymm10\n"
        "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n"
        "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n"
        "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n"
        "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n"
        "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n"
        "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n"
        "vbroadcastss 196(%0), %%ymm11\n"
        "vbroadcastss 228(%0), %%ymm12\n"
        "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n"
        "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
        // 3
        "vmovups 64(%1), %%ymm8\n"
        "vbroadcastss 8(%0), %%ymm13\n"
        "vbroadcastss 40(%0), %%ymm14\n"
        "vbroadcastss 72(%0), %%ymm9\n"
        "vbroadcastss 104(%0), %%ymm10\n"
        "vbroadcastss 136(%0), %%ymm11\n"
        "vbroadcastss 168(%0), %%ymm12\n"
        "vfmadd231ps %%ymm13, %%ymm8, %%ymm0\n"
        "vfmadd231ps %%ymm14, %%ymm8, %%ymm1\n"
        "vfmadd231ps %%ymm9, %%ymm8, %%ymm2\n"
        "vfmadd231ps %%ymm10, %%ymm8, %%ymm3\n"
        "vfmadd231ps %%ymm11, %%ymm8, %%ymm4\n"
        "vfmadd231ps %%ymm12, %%ymm8, %%ymm5\n"
        "vbroadcastss 200(%0), %%ymm13\n"
        "vbroadcastss 232(%0), %%ymm14\n"
        "vfmadd231ps %%ymm13, %%ymm8, %%ymm6\n"
        "vfmadd231ps %%ymm14, %%ymm8, %%ymm7\n"
        // 4
        "vmovups 96(%1), %%ymm15\n"
        "vbroadcastss 12(%0), %%ymm9\n"
        "vbroadcastss 44(%0), %%ymm10\n"
        "vbroadcastss 76(%0), %%ymm11\n"
        "vbroadcastss 108(%0), %%ymm12\n"
        "vbroadcastss 140(%0), %%ymm13\n"
        "vbroadcastss 172(%0), %%ymm14\n"
        "vfmadd231ps %%ymm9, %%ymm15, %%ymm0\n"
        "vfmadd231ps %%ymm10, %%ymm15, %%ymm1\n"
        "vfmadd231ps %%ymm11, %%ymm15, %%ymm2\n"
        "vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n"
        "vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n"
        "vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n"
        "vbroadcastss 204(%0), %%ymm9\n"
        "vbroadcastss 236(%0), %%ymm10\n"
        "vfmadd231ps %%ymm9, %%ymm15, %%ymm6\n"
        "vfmadd231ps %%ymm10, %%ymm15, %%ymm7\n"
        // 5
        "vmovups 128(%1), %%ymm8\n"
        "vbroadcastss 16(%0), %%ymm11\n"
        "vbroadcastss 48(%0), %%ymm12\n"
        "vbroadcastss 80(%0), %%ymm13\n"
        "vbroadcastss 112(%0), %%ymm14\n"
        "vbroadcastss 144(%0), %%ymm9\n"
        "vbroadcastss 176(%0), %%ymm10\n"
        "vfmadd231ps %%ymm11, %%ymm8, %%ymm0\n"
        "vfmadd231ps %%ymm12, %%ymm8, %%ymm1\n"
        "vfmadd231ps %%ymm13, %%ymm8, %%ymm2\n"
        "vfmadd231ps %%ymm14, %%ymm8, %%ymm3\n"
        "vfmadd231ps %%ymm9, %%ymm8, %%ymm4\n"
        "vfmadd231ps %%ymm10, %%ymm8, %%ymm5\n"
        "vbroadcastss 208(%0), %%ymm11\n"
        "vbroadcastss 240(%0), %%ymm12\n"
        "vfmadd231ps %%ymm11, %%ymm8, %%ymm6\n"
        "vfmadd231ps %%ymm12, %%ymm8, %%ymm7\n"
        // 6
        "vmovups 160(%1), %%ymm15\n"
        "vbroadcastss 20(%0), %%ymm13\n"
        "vbroadcastss 52(%0), %%ymm14\n"
        "vbroadcastss 84(%0), %%ymm9\n"
        "vbroadcastss 116(%0), %%ymm10\n"
        "vbroadcastss 148(%0), %%ymm11\n"
        "vbroadcastss 180(%0), %%ymm12\n"
        "vfmadd231ps %%ymm13, %%ymm15, %%ymm0\n"
        "vfmadd231ps %%ymm14, %%ymm15, %%ymm1\n"
        "vfmadd231ps %%ymm9, %%ymm15, %%ymm2\n"
        "vfmadd231ps %%ymm10, %%ymm15, %%ymm3\n"
        "vfmadd231ps %%ymm11, %%ymm15, %%ymm4\n"
        "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n"
        "vbroadcastss 212(%0), %%ymm13\n"
        "vbroadcastss 244(%0), %%ymm14\n"
        "vfmadd231ps %%ymm13, %%ymm15, %%ymm6\n"
        "vfmadd231ps %%ymm14, %%ymm15, %%ymm7\n"
        // 7
        "vmovups 192(%1), %%ymm8\n"
        "vbroadcastss 24(%0), %%ymm9\n"
        "vbroadcastss 56(%0), %%ymm10\n"
        "vbroadcastss 88(%0), %%ymm11\n"
        "vbroadcastss 120(%0), %%ymm12\n"
        "vbroadcastss 152(%0), %%ymm13\n"
        "vbroadcastss 184(%0), %%ymm14\n"
        "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n"
        "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n"
        "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n"
        "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n"
        "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n"
        "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n"
        "vbroadcastss 216(%0), %%ymm9\n"
        "vbroadcastss 248(%0), %%ymm10\n"
        "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n"
        "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n"
        // 8
        "vmovups 224(%1), %%ymm15\n"
        "vbroadcastss 28(%0), %%ymm11\n"
        "vbroadcastss 60(%0), %%ymm12\n"
        "vbroadcastss 92(%0), %%ymm13\n"
        "vbroadcastss 124(%0), %%ymm14\n"
        "vbroadcastss 156(%0), %%ymm9\n"
        "vbroadcastss 188(%0), %%ymm10\n"
        "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n"
        "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n"
        "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n"
        "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n"
        "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n"
        "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n"
        "vbroadcastss 220(%0), %%ymm11\n"
        "vbroadcastss 252(%0), %%ymm12\n"
        "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n"
        "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n"
        :
        : "r"(src), "r"(weight)
        : "memory");
#else
      // Reference intrinsic version: there are not enough vector registers
      // here for the compiler to allocate well (it spills badly), which is
      // why the release path uses the hand-written asm above to control
      // register assignment and scheduling.
      for (int j = 0; j < C8NUM; ++j) {
        __m256 weight_data = _mm256_loadu_ps(weight + j * C8NUM);
        dst1 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j)), dst1);
        dst2 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C8NUM)), dst2);
        dst3 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C16NUM)), dst3);
        dst4 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C24NUM)), dst4);
        dst5 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C32NUM)), dst5);
        dst6 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C40NUM)), dst6);
        dst7 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C48NUM)), dst7);
        dst8 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C56NUM)), dst8);
      }
#endif
      // Both layouts are consumed contiguously, 64 floats per chunk.
      src += C64NUM;
      weight += C64NUM;
    }
    // Write the 8 accumulated rows of this output-channel block.
    _mm256_storeu_ps(dst, dst1);
    _mm256_storeu_ps(dst + C8NUM, dst2);
    _mm256_storeu_ps(dst + C16NUM, dst3);
    _mm256_storeu_ps(dst + C24NUM, dst4);
    _mm256_storeu_ps(dst + C32NUM, dst5);
    _mm256_storeu_ps(dst + C40NUM, dst6);
    _mm256_storeu_ps(dst + C48NUM, dst7);
    _mm256_storeu_ps(dst + C56NUM, dst8);
    dst += cal_num;
  }
}
原始代码块:
// Original (naive) inner loop shown for comparison with the tiled kernels
// above: for each depth element ic, broadcast one scalar per row and run a
// single FMA per row against one 8-wide weight vector.
// Fixed two transcription typos: `_mm256` -> `__m256` (the intrinsic vector
// type) and `dats_data` -> `dst_data`.
for (int ic = 0; ic < deep; ++ic) {
  __m256 s1 = _mm256_set1_ps(src[ic]);
  __m256 s2 = _mm256_set1_ps((src + deep)[ic]);
  __m256 s3 = _mm256_set1_ps((src + 2 * deep)[ic]);
  __m256 weight_data = _mm256_loadu_ps(weight + 0 * C8NUM);
  dst_data[0] = _mm256_fmadd_ps(s1, weight_data, dst_data[0]);
  dst_data[4] = _mm256_fmadd_ps(s2, weight_data, dst_data[4]);
  dst_data[8] = _mm256_fmadd_ps(s3, weight_data, dst_data[8]);
}
反汇编代码:
http://godbolt.org/
// s1 s2 s3读取数据没有间接寻址哦!
并未拿寄存器的间接寻址展开(s1, s3, s2),而是将下标存到寄存器中,这在高版本的编译器中竟然也没有解决!!
如果是6x16block,该问题更加严重,我们的dst和src都需要存储到寄存器,至少需要6个通用寄存器,而通过寄存器间接寻址我们需要2个就能解决问题!(尤其在卷积资源匮乏情况下!)