AVX介绍

1 AVX 介绍

1.1 SIMD

SIMD:Single Instruction Multiple Data,单指令多数据,一个指令可以控制多个数据进行操作。

最简单的例子,在向量加法中,对每一个维度的值,都要进行加法运算:

// a=[a1, a2, a3, a4], b=[b1, b2, b3, b4]
sum[0]=a[0]+b[0];
sum[1]=a[1]+b[1];
sum[2]=a[2]+b[2];
sum[3]=a[3]+b[3];

在这里,使用的是单指令单数据(SISD)的处理方式,要进行四次加法运算,就当真是进行了四次加法运算,使用了四次加法指令。

那么,也可以使用单指令多数据(SIMD)的方式来处理:

sum_vector4 = a_vector4 + b_vector4;

进行4次加法只需要进行一个长度为4的向量加法,只是用了一次向量加法的指令。

在SIMD指令集中,可以控制若干个大寄存器,把这些寄存器中的数据按照某些规则进行统一的操作,相当于一条指令可以完成好几次重复运算,从而达到加快运算速度的效果。这里没有使用汇编指令,而是利用所谓的Intrinsics(内置)函数。

1.2 Intrinsics:直接映射到汇编的函数(指令)

(Intel® Intrinsics Guide)[https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html]

1.3 AVX 简介

在 CPU 中,有一些诸如 xmm,ymm,zmm 的寄存器,它们是用于SIMD指令的基础寄存器。为了充分利用寄存器,它们既可以用于整数计算,也可以用于浮点数计算(32位和64位都可以),为了区分寄存器的实际用途,指令具有这样的格式:

_mm256_add_ps
// 指令前缀,mm代表是SIMD指令集,256代表使用的寄存器宽度是256位
// 指令作用,add代表对每个元素进行加法运算
// 操作数类型,ps解释成packed single(float),即单浮点数向量。pd解释为packed double,还有各种整数类型

// 为了防止混用了类型,预定义了三种类型:
__m256  
// 最常用的类型,位宽256,解释为8个float

__m256d 
// 位宽256,解释为4个double

__m256i
// 位宽256,解释为若干个整数,具体怎么做取决于指令(因为整数的按位拆分组合更加频繁)

指令与类型是相互匹配的。如果必要也可以用convert指令进行强制类型转换,它不产生机器码,只是告诉编译器这个寄存器的原值可以直接用于其他类型。

2 AVX 浮点系列指令简介

与普通结构的指令集类似,AVX系列指令也可以分为若干类型。

2.1 内存访问指令

f32x8_p源代码(局部)

static f32x8_p load_8floats(const void *arr) { return f32x8_p(_mm256_loadu_ps((float *)arr)); }
static f32x8_p load_1float_broadcast(float *arr) { return f32x8_p(_mm256_broadcast_ss(arr)); }
static f32x8_p load_4floats_broadcast(const __m128 *arr) { return f32x8_p(_mm256_broadcast_ps(arr)); }
static f32x8_p load_mask(float const *arr, __m256i mask) { return f32x8_p(_mm256_maskload_ps(arr, mask)); }

void store(float *a) { _mm256_storeu_ps(a, data); }
void store(f32x8_p *a) { _mm256_store_ps((float *)a, data); }
void load(float *a) { data = _mm256_loadu_ps(a); }
void load(f32x8_p *a) { data = _mm256_load_ps((float *)a); }

可以看到有若干store/load指令。同时由于SIMD指令的特殊性,给出了一个广播加载指令,可以用一个数据初始化所有值,在某些时候会用到。

2.2 算数运算指令

f32x8_p源代码(局部)

f32x8_p operator+(f32x8_p a) { return f32x8_p(_mm256_add_ps(data, a.data)); }
f32x8_p operator-(f32x8_p a) { return f32x8_p(_mm256_sub_ps(data, a.data)); }
f32x8_p operator*(f32x8_p a) { return f32x8_p(_mm256_mul_ps(data, a.data)); }
f32x8_p operator/(f32x8_p a) { return f32x8_p(_mm256_div_ps(data, a.data)); }
void operator+=(f32x8_p a) { data = _mm256_add_ps(data, a.data); }
void operator-=(f32x8_p a) { data = _mm256_sub_ps(data, a.data); }
void operator*=(f32x8_p a) { data = _mm256_mul_ps(data, a.data); }
void operator/=(f32x8_p a) { data = _mm256_div_ps(data, a.data); }
2.3 逻辑运算指令

各个数据类型下的逻辑运算并没有结果上的区别,虽然指令看上去可能不一样

f32x8_p operator&(f32x8_p a) { return f32x8_p(_mm256_and_ps(data, a.data)); }
void operator&=(f32x8_p a) { data = _mm256_and_ps(data, a.data); }
f32x8_p operator|(f32x8_p a) { return f32x8_p(_mm256_or_ps(data, a.data)); }
void operator|=(f32x8_p a) { data = _mm256_or_ps(data, a.data); }
f32x8_p operator^(f32x8_p a) { return f32x8_p(_mm256_xor_ps(data, a.data)); }
void operator^=(f32x8_p a) { data = _mm256_xor_ps(data, a.data); }
static f32x8_p andnot(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_andnot_ps(a.data, b.data)); }
void do_andnot(f32x8_p a) { data = _mm256_andnot_ps(data, a.data); }
2.4 高级算数运算指令

对数据进行某种不统一但是有规律的算数运算:

// {a[0]-b[0], a[1]+b[1], a[2]-b[2], a[3]+b[3]...}
void addsub(f32x8_p a) { data = _mm256_addsub_ps(data, a.data); }
static f32x8_p addsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_addsub_ps(a.data, b.data)); }
// {a[0]+a[1], a[2]+a[3], b[0]+b[1], b[2]+b[3]...}
void hadd(f32x8_p a) { data = _mm256_hadd_ps(data, a.data); }
static f32x8_p hadd(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hadd_ps(a.data, b.data)); }
// {a[0]-a[1], a[2]-a[3], b[0]-b[1], b[2]-b[3]...}
void hsub(f32x8_p a) { data = _mm256_hsub_ps(data, a.data); }
static f32x8_p hsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hsub_ps(a.data, b.data)); }
// {a[0]*mul[0]+add[0], a[1]*mul[1]+add[1]...}
void mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }
void fmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }
static f32x8_p mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fmadd_ps(a.data, mul.data, add.data)); }
// {a[0]*mul[0]-add[0], a[1]*mul[1]-add[1]...}
void mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, sub.data); }
void fmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, sub.data); }
static f32x8_p mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fmsub_ps(a.data, mul.data, sub.data)); }
// {-a[0]*mul[0]+add[0], -a[1]*mul[1]+add[1]...}
void neg_mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }
void fnmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }
static f32x8_p neg_mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fnmadd_ps(a.data, mul.data, add.data)); }
// {-a[0]*mul[0]-add[0], -a[1]*mul[1]-add[1]...}
void neg_mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }
void fnmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }
static f32x8_p neg_mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fnmsub_ps(a.data, mul.data, sub.data)); }
// {a[0] * mul[0] + addsub[0], a[1] * mul[1] - addsub[1]...}
void mul_addsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }
void fmaddsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }
static f32x8_p mul_addsub(f32x8_p a, f32x8_p mul, f32x8_p addsub) { return f32x8_p(_mm256_fmaddsub_ps(a.data, mul.data, addsub.data)); }
// {a[0] * mul[0] - subadd[0], a[1] * mul[1] + subadd[1]...}
void mul_subadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }
void fmsubadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }
static f32x8_p mul_subadd(f32x8_p a, f32x8_p mul, f32x8_p subadd) { return f32x8_p(_mm256_fmsubadd_ps(a.data, mul.data, subadd.data)); }

交错加减运算、横向加减法运算(向量内部的相邻元素做加减法)、带乘法的三操作数加减运算。

2.5 数据重排指令

将一个寄存器的数据按照一定的规则重排,或者将两个寄存器的数据按照一定规则重排到一个寄存器;有时候需要运算的数据并没有正对着,需要重新排列一下。

需要注意的是,作为整数的重排规则需要是编译期常量,因为这是机器码的一部分。因此需要使用模板参数而非普通函数参数。

这里的注释中,对imm8的运算符[]是按位访问的,imm[1:0]代表由imm的最低两位组成的0~3的整数值。

// {a[2], b[2], a[3], b[3], a[6], b[6], a[7], b[7]}
static f32x8_p unpack_high(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpackhi_ps(a.data, b.data)); }
// {a[0], b[0], a[1], b[1], a[4], b[4], a[5], b[5]}
static f32x8_p unpack_low(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpacklo_ps(a.data, b.data)); }

// for j = 0 to 7, ret[j] = imm8[j]?a[j]:b[j]
template <uint8_t _imm8>
void blend(f32x8_p a) { data = _mm256_blend_ps(data, a.data, _imm8); }
template <uint8_t _imm8>
static f32x8_p blend(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_blend_ps(a.data, b.data, _imm8)); }

// for j = 0 to 7, ret[j] = mask[j].signbit?a[j]:b[j]
void blend(f32x8_p a, f32x8_p mask) { data = _mm256_blendv_ps(data, a.data, mask.data); }
static f32x8_p blend(f32x8_p a, f32x8_p b, f32x8_p mask) { return f32x8_p(_mm256_blendv_ps(a.data, b.data, mask.data)); }

// see f32x4_p::shuffle. do the same shuffle for both high and low 128bit
// {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}
template <uint8_t _imm8>
void shuffle(f32x8_p a) { data = _mm256_shuffle_ps(data, a.data, _imm8); }

// {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}
template <uint8_t _imm8>
void permute() { data = _mm256_permute_ps(data, _imm8); }

// ret.low  = imm8[2]? 0 : switch imm8[1:0] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}
// ret.high = imm8[6]? 0 : switch imm8[5:4] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}
template <uint8_t _imm8>
void permute2f128(f32x8_p a) { data = _mm256_permute2f128_ps(data, a.data, _imm8); }

void move_high_dup() { data = _mm256_movehdup_ps(data); }
void move_odd2even() { data = _mm256_movehdup_ps(data); }
f32x8_p copy_odd2even() { return f32x8_p(_mm256_movehdup_ps(data)); }
void move_low_dup() { data = _mm256_moveldup_ps(data); }
void move_even2odd() { data = _mm256_moveldup_ps(data); }
f32x8_p copy_even2odd() { return f32x8_p(_mm256_moveldup_ps(data)); }
2.6 常用函数指令

取倒数、开平方根、取平方根的倒数:

void rcp() { data = _mm256_rcp_ps(data); }
void sqrt() { data = _mm256_sqrt_ps(data); }
void rsqrt() { data = _mm256_rsqrt_ps(data); }
2.7 掩码指令

通过掩码指定真正参与运算的操作数是来自哪个向量。

3 复数乘法

这里讨论的是最经典的复数存储格式,实部与虚部交错存储:

{c[0].re, c[0].im,  c[1].re, c[1].im,  c[2].re, c[2].im,  c[3].re, c[3].im}

为简便起见,我们暂时只看第一个复数 a + b i a+bi a+bi:

var a_bi={a, b, ...}

把它乘以另一个复数 c + d i c+di c+di

var c_di={c, d, ...}

希望得到的结果是 ( a c − b d ) + ( a d + b c ) i (ac-bd) + (ad+bc)i (acbd)+(ad+bc)i

{ac-bd, ad+bc, ...}
3.1 方法1
static fc32x4_p multiply_complex_v0(f32x8_p a_bI, f32x8_p c_dI)
{ // real: a*c - b*d, imag: a*d + b*c
    f32x8_p ac_bdI = a_bI * c_dI;
    f32x8_p ad_bcI = a_bI * c_dI.reordered<0b10'11'00'01>();
    ac_bdI.hsub(ac_bdI);
    ad_bcI.hadd(ad_bcI);
    return fc32x4_p(ac_bdI
                        .remixed<0b11'10'01'00>(ad_bcI) // {r0,r1,i0,i1}
                        .reordered<0b11'01'10'00>());   // {r0,i0,r1,i1}
}
  • 说明
    • ac_bdI = a_bI * c_dI :它的值为a*c,b*d, ...,实际上就是实部的两个值

    • 重排c_dI, 重排规则 0b10'11'00'01 的含义如下:

      0b10'11'00'01
          dst[0]=src[0b01]
          dst[1]=src[0b00]
          dst[2]=src[0b11]
          dst[3]=src[0b10]
      

      实际上就是 交换了它的实部与虚部,得到了: d + c i d+ci d+ci

    • a_bI * c_dI.reordered<0b10'11'00'01>() = ad_bcI = a_bI * d_cI, 它的值为 ad, bc, ..., 实际上就是虚部的两个值

    • 横向减法得到实部,第二操作数为自身

      ac_bdI.hsub(ac_bdI);
      

      它的效果是,对每个128bit:

      dst[0]=a[1]-a[0]
      dst[1]=a[3]-a[2]
      dst[2]=b[1]-b[0]
      dst[3]=b[3]-b[2]
      

      得到的结果是这样的:

      c[0].re, c[1].re, c[0].re, c[1].re, ...
      
    • 进行横向加法得到虚部

      ad_bcI.hadd(ad_bcI);
      

      虚部:

      c[0].im, c[1].im, c[0].im, c[1].im, ...
      
    • 对实部和虚部进行适当的重排即可得到答案
      合并,去掉重复的部分:结果的前两路来自第一操作数,后两路来自第二操作数

       return fc32x4_p(ac_bdI.remixed<0b11'10'01'00>(ad_bcI) // {r0,r1,i0,i1}
      

      实部挨在一起了:

      c[0].re, c[1].re, c[0].im, c[1].im, ...
      

      交换一下[1]号元素和[2]号元素得到最终结果:

       .reordered<0b11'01'10'00>());   // {r0,i0,r1,i1}
      
3.2 方法二
// complex mul complex, use addsub, cost about 90% time of _v0
static fc32x4_p multiply_complex_v1(f32x8_p a_bI, f32x8_p c_dI)
{ // real: a*c - b*d, imag: a*d + b*c
    f32x8_p a_aI = a_bI.copy_even2odd();
    f32x8_p b_bI = a_bI.copy_odd2even();
    f32x8_p ac_adI = a_aI * c_dI;
    f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();
    return fc32x4_p(fc32x4_p::addsub(ac_adI, bd_bcI));
}

这次可以逆着来构造,考虑到前面有一个交错加减的函数,用来计算复数乘法比较合适。为了使用交错加减得到复数相乘的结果

{ac-bd, ad+bc, ...}

可以通过这样的运算得到:

{ac, bc}(+/-){ad, bd}
// 或者
{ac, ad}(+/-){bd, bc}

观察第二个式子不难发现:

{ac, ad} = {a, a} * {c, d}
{bd, bc} = {b, b} * {d, c}

// -->
{a, a}, {c, d}, {b, b}, {d, c}  // 这个可以通过重排函数得到
3.3 方法三

使用乘法后交错加减指令优化一下下方案二

// complex mul complex, use mul_addsub, cost about 98% time of _v1
static fc32x4_p multiply_complex_v2(f32x8_p a_bI, f32x8_p c_dI)
{ // real: a*c - b*d, imag: a*d + b*c
    f32x8_p a_aI = a_bI.copy_even2odd();
    f32x8_p b_bI = a_bI.copy_odd2even();
    //f32x8_p ac_adI = a_aI * c_dI;
    f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();
    return fc32x4_p::mul_addsub(a_aI, c_dI, bd_bcI);
}
3.4 继续优化

这个封装类因为有源码实现,所以可以直接优化并内联。如果需要封装SIMD指令的函数,可以考虑__vectorcall。__vectorcall使用xmm,ymm等寄存器传递参数和返回值。如下:

_declspec(dllexport) auto __vectorcall multiply_complex_v0(__m256 a, __m256 b)
{
    return fc32x4_p::multiply_complex_v0(a, b).data;
}
//_declspec(dllexport) auto __vectorcall multiply_complex_v2(__m256 a, __m256 b)
        //{
C5 FC 10 E0          vmovups     ymm4,ymm0  
            //return fc32x4_p::multiply_complex_v2(a, b).data;
C5 FE 16 D0          vmovshdup   ymm2,ymm0  
C4 E3 7D 04 D9 B1    vpermilps   ymm3,ymm1,0B1h  //看这里,立即数0B1h写死在代码里的
C5 E4 59 C2          vmulps      ymm0,ymm3,ymm2  
C5 FE 12 E4          vmovsldup   ymm4,ymm4  
C4 E2 5D B6 C1       vfmaddsub231ps ymm0,ymm4,ymm1  
        //}
C3                   ret  

AVX256 源码

#pragma once
#ifndef bionukg_SIMD_h
#define bionukg_SIMD_h

// sse / avx
#include <stdint.h>
#include <xmmintrin.h> //__m128, f32x4
#include <emmintrin.h> //__m128i,__m128d
#include <immintrin.h> //__m256 series


#ifdef namespace_bionukg
namespace bionukg
{
#endif

    struct f32x4_b; // basic float32x4
    struct f32x4_p; // packed float32x4
    struct f32x4_s; // single float32x4


    struct f32x4_b
    {
    public:
        union
        {
            __m128 data;
            float f32x4[4];
        };
        f32x4_b() : data(_mm_setzero_ps()) {}
        f32x4_b(__m128 data) : data(data) {}
        f32x4_b(float a, float b, float c, float d) : data(_mm_setr_ps(a, b, c, d)) {}
        f32x4_b(float a) : data(_mm_set1_ps(a)) {}
        f32x4_b(float *a) : data(_mm_loadu_ps(a)) {}
        f32x4_b(f32x4_b *a) : data(_mm_load_ps((float *)a)) {}

        void store(float *a) { _mm_storeu_ps(a, data); }
        void store(f32x4_b *a) { _mm_store_ps((float *)a, data); }
        void load(float *a) { data = _mm_loadu_ps(a); }
        void load(f32x4_b *a) { data = _mm_load_ps((float *)a); }
        float operator[](uint8_t idx) const { return data.m128_f32[idx]; }
        float &operator[](uint8_t idx) { return data.m128_f32[idx]; }
    };

    struct f32x4_s : public f32x4_b
    {
    public:
        f32x4_s() : f32x4_b() {}
        f32x4_s(__m128 data) : f32x4_b(data) {}
        f32x4_s(f32x4_b a) : f32x4_b(a) {}
        f32x4_s(float a, float b, float c, float d) : f32x4_b(a, b, c, d) {}
        f32x4_s(float a) : f32x4_b(a) {}
        f32x4_s(float *a) : f32x4_b(a) {}
        f32x4_s(f32x4_b *a) : f32x4_b(a) {}

        int32_t get_int32() const { return _mm_cvtss_si32(data); }
        int32_t get_int32_trunc() const { return _mm_cvttss_si32(data); }
        int64_t get_int64() const { return _mm_cvtss_si64(data); }
        int64_t get_int64_trunc() const { return _mm_cvttss_si64(data); }
        void put_int32(int32_t a) { data = _mm_cvtsi32_ss(data, a); }
        void put_int64(int64_t a) { data = _mm_cvtsi64_ss(data, a); }
        float get_float() const { return _mm_cvtss_f32(data); }

        f32x4_s sqrt() { return f32x4_s(_mm_sqrt_ss(data)); }
        void do_sqrt() { data = _mm_sqrt_ss(data); }
    };

    struct f32x4_p : public f32x4_b
    {
    public:
        f32x4_p() : f32x4_b() {}
        f32x4_p(__m128 data) : f32x4_b(data) {}
        f32x4_p(f32x4_b a) : f32x4_b(a) {}
        f32x4_p(float a, float b, float c, float d) : f32x4_b(a, b, c, d) {}
        f32x4_p(float a) : f32x4_b(a) {}
        f32x4_p(float *a) : f32x4_b(a) {}
        f32x4_p(f32x4_b *a) : f32x4_b(a) {}

        f32x4_p copy() const { return f32x4_p(data); }
        f32x4_s as_single() const { return f32x4_s(data); }
        f32x4_s &as_single_do() { return reinterpret_cast<f32x4_s &>(data); }

        f32x4_p operator+(f32x4_p a) { return f32x4_p(_mm_add_ps(data, a.data)); }
        f32x4_p operator+(f32x4_s a) { return f32x4_p(_mm_add_ss(data, a.data)); }
        f32x4_p operator-(f32x4_p a) { return f32x4_p(_mm_sub_ps(data, a.data)); }
        f32x4_p operator-(f32x4_s a) { return f32x4_p(_mm_sub_ss(data, a.data)); }
        f32x4_p operator*(f32x4_p a) { return f32x4_p(_mm_mul_ps(data, a.data)); }
        f32x4_p operator*(f32x4_s a) { return f32x4_p(_mm_mul_ss(data, a.data)); }
        f32x4_p operator/(f32x4_p a) { return f32x4_p(_mm_div_ps(data, a.data)); }
        f32x4_p operator/(f32x4_s a) { return f32x4_p(_mm_div_ss(data, a.data)); }
        void operator+=(f32x4_p a) { data = _mm_add_ps(data, a.data); }
        void operator+=(f32x4_s a) { data = _mm_add_ss(data, a.data); }
        void operator-=(f32x4_p a) { data = _mm_sub_ps(data, a.data); }
        void operator-=(f32x4_s a) { data = _mm_sub_ss(data, a.data); }
        void operator*=(f32x4_p a) { data = _mm_mul_ps(data, a.data); }
        void operator*=(f32x4_s a) { data = _mm_mul_ss(data, a.data); }
        void operator/=(f32x4_p a) { data = _mm_div_ps(data, a.data); }
        void operator/=(f32x4_s a) { data = _mm_div_ss(data, a.data); }

        f32x4_p sqrt() const { return f32x4_p(_mm_sqrt_ps(data)); }
        void do_sqrt() { data = _mm_sqrt_ps(data); }
        void do_sqrt_single() { data = _mm_sqrt_ss(data); }

        f32x4_p rcp() const { return f32x4_p(_mm_rcp_ps(data)); }
        void do_rcp() { data = _mm_rcp_ps(data); }
        void do_rcp_single() { data = _mm_rcp_ss(data); }

        f32x4_p rsqrt() const { return f32x4_p(_mm_rsqrt_ps(data)); }
        void do_rsqrt() { data = _mm_rsqrt_ps(data); }
        void do_rsqrt_single() { data = _mm_rsqrt_ss(data); }

        static f32x4_p minimum(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_min_ps(a.data, b.data)); }
        static f32x4_p minimum(f32x4_p a, f32x4_s b) { return f32x4_p(_mm_min_ss(a.data, b.data)); }
        static f32x4_p maximum(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_max_ps(a.data, b.data)); }
        static f32x4_p maximum(f32x4_p a, f32x4_s b) { return f32x4_p(_mm_max_ss(a.data, b.data)); }

        f32x4_p operator&(f32x4_p a) { return f32x4_p(_mm_and_ps(data, a.data)); }
        void operator&=(f32x4_p a) { data = _mm_and_ps(data, a.data); }
        f32x4_p operator|(f32x4_p a) { return f32x4_p(_mm_or_ps(data, a.data)); }
        void operator|=(f32x4_p a) { data = _mm_or_ps(data, a.data); }
        f32x4_p operator^(f32x4_p a) { return f32x4_p(_mm_xor_ps(data, a.data)); }
        void operator^=(f32x4_p a) { data = _mm_xor_ps(data, a.data); }
        static f32x4_p andnot(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_andnot_ps(a.data, b.data)); }
        void do_andnot(f32x4_p a) { data = _mm_andnot_ps(data, a.data); }

        f32x4_p operator==(f32x4_p a) { return f32x4_p(_mm_cmpeq_ps(data, a.data)); }
        int operator==(f32x4_s a) { return _mm_comieq_ss(data, a.data); }
        f32x4_p operator!=(f32x4_p a) { return f32x4_p(_mm_cmpneq_ps(data, a.data)); }
        int operator!=(f32x4_s a) { return _mm_comineq_ss(data, a.data); }
        f32x4_p operator<(f32x4_p a) { return f32x4_p(_mm_cmplt_ps(data, a.data)); }
        int operator<(f32x4_s a) { return _mm_comilt_ss(data, a.data); }
        f32x4_p operator<=(f32x4_p a) { return f32x4_p(_mm_cmple_ps(data, a.data)); }
        int operator<=(f32x4_s a) { return _mm_comile_ss(data, a.data); }
        f32x4_p operator>(f32x4_p a) { return f32x4_p(_mm_cmpgt_ps(data, a.data)); }
        int operator>(f32x4_s a) { return _mm_comigt_ss(data, a.data); }
        f32x4_p operator>=(f32x4_p a) { return f32x4_p(_mm_cmpge_ps(data, a.data)); }
        int operator>=(f32x4_s a) { return _mm_comige_ss(data, a.data); }
        f32x4_p has_NAN() { return f32x4_p(_mm_cmpunord_ps(data, data)); }
        f32x4_p has_NAN(f32x4_p a) { return f32x4_p(_mm_cmpunord_ps(data, a.data)); }
        f32x4_p not_NAN() { return f32x4_p(_mm_cmpord_ps(data, data)); }
        f32x4_p not_NAN(f32x4_p a) { return f32x4_p(_mm_cmpord_ps(data, a.data)); }

        // {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}
        template <uint8_t _imm8>
        void shuffle(f32x4_p a) { data = _mm_shuffle_ps(data, a.data, _imm8); }
        /*
        DEFINE SELECT4(src, control) {
        CASE(control[1:0]) OF
        0:	tmp[31:0] := src[31:0]
        1:	tmp[31:0] := src[63:32]
        2:	tmp[31:0] := src[95:64]
        3:	tmp[31:0] := src[127:96]
        ESAC
        RETURN tmp[31:0]
        }
        dst[31:0] := SELECT4(a[127:0], imm8[1:0])
        dst[63:32] := SELECT4(a[127:0], imm8[3:2])
        dst[95:64] := SELECT4(b[127:0], imm8[5:4])
        dst[127:96] := SELECT4(b[127:0], imm8[7:6])
        */
        // {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}
        template <uint8_t _imm8>
        void permute() { data = _mm_permute_ps(data, _imm8); }
        /*
        DEFINE SELECT4(src, control) {
        CASE(control[1:0]) OF
        0:	tmp[31:0] := src[31:0]
        1:	tmp[31:0] := src[63:32]
        2:	tmp[31:0] := src[95:64]
        3:	tmp[31:0] := src[127:96]
        ESAC
        RETURN tmp[31:0]
        }
        dst[31:0] := SELECT4(a[127:0], imm8[1:0])
        dst[63:32] := SELECT4(a[127:0], imm8[3:2])
        dst[95:64] := SELECT4(a[127:0], imm8[5:4])
        dst[127:96] := SELECT4(a[127:0], imm8[7:6])
        dst[MAX:128] := 0
        */
        // {a[3], b[3], a[4], b[4]}
        static f32x4_p unpack_high(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_unpackhi_ps(a.data, b.data)); }
        void unpack_high(f32x4_p a) { data = _mm_unpackhi_ps(data, a.data); }
        // {a[0], b[0], a[1], b[1]}
        static f32x4_p unpack_low(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_unpacklo_ps(a.data, b.data)); }
        void unpack_low(f32x4_p a) { data = _mm_unpacklo_ps(data, a.data); }
        // {a[0], a[1], mem[0], mem[1]}
        static f32x4_p load_high(f32x4_p a, float *mem_addr) { return f32x4_p(_mm_loadh_pi(a.data, (__m64 *)mem_addr)); }
        void load_high(float *mem_addr) { data = _mm_loadh_pi(data, (__m64 *)mem_addr); }
        void store_high(float *mem_addr) { _mm_storeh_pi((__m64 *)mem_addr, data); }
        // {mem[0], mem[1], a[0], a[1]}
        static f32x4_p load_low(f32x4_p a, float *mem_addr) { return f32x4_p(_mm_loadl_pi(a.data, (__m64 *)mem_addr)); }
        void load_low(float *mem_addr) { data = _mm_loadl_pi(data, (__m64 *)mem_addr); }
        void store_low(float *mem_addr) { _mm_storel_pi((__m64 *)mem_addr, data); }
        // {b[2], b[3], a[2], a[3]}
        static f32x4_p move_high2low(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_movehl_ps(a.data, b.data)); }
        void move_high2low(f32x4_p a) { data = _mm_movehl_ps(data, a.data); }
        void move_h2l(f32x4_p a) { data = _mm_movehl_ps(data, a.data); }
        // {a[0], a[1], b[0], b[1]}
        static f32x4_p move_low2high(f32x4_p a, f32x4_p b) { return f32x4_p(_mm_movelh_ps(a.data, b.data)); }
        void move_low2high(f32x4_p a) { data = _mm_movelh_ps(data, a.data); }
        void move_l2h(f32x4_p a) { data = _mm_movelh_ps(data, a.data); }

        int get_mask() const { return _mm_movemask_ps(data); }
        int get_signs() const { return _mm_movemask_ps(data); }
    };

    struct f32x8_p
    {
    public:
        union
        {
            __m256 data;
            float f32x8[8];
        };
        f32x8_p() : data(_mm256_setzero_ps()) {}
        f32x8_p(__m256 data) : data(data) {}
        f32x8_p(f32x4_p high, f32x4_p low) : data(_mm256_set_m128(high.data, low.data)) {}
        f32x8_p(float a, float b, float c, float d, float e = 0.0f, float f = 0.0f, float g = 0.0f, float h = 0.0f) : data(_mm256_setr_ps(a, b, c, d, e, f, g, h)) {}
        f32x8_p(float a) : data(_mm256_set1_ps(a)) {}
        f32x8_p(float *a) : data(_mm256_load_ps((float *)a)) {}

        __m256d cast_packed_double() const { return _mm256_castps_pd(data); }
        __m256d cast_packed_double() { return _mm256_castps_pd(data); }
        __m256d cast_packed_f64() const { return _mm256_castps_pd(data); }
        __m256d cast_packed_f64() { return _mm256_castps_pd(data); }
        __m256d cast_pd() const { return _mm256_castps_pd(data); }
        __m256d cast_pd() { return _mm256_castps_pd(data); }
        __m256i cast_packed_int() const { return _mm256_castps_si256(data); }
        __m256i cast_packed_int() { return _mm256_castps_si256(data); }
        __m256i cast_si() const { return _mm256_castps_si256(data); }
        __m256i cast_si() { return _mm256_castps_si256(data); }
        __m128 cast_down() const { return _mm256_castps256_ps128(data); }
        __m128 cast_down() { return _mm256_castps256_ps128(data); }
        template <uint8_t _imm3 = 0b000>
        __m128i convert_packed_half_float() const { return _mm256_cvtps_ph(data, _imm3); }
        template <uint8_t _imm3 = 0b000>
        __m128i convert_packed_f16() const { return _mm256_cvtps_ph(data, _imm3); }
        template <uint8_t _imm3 = 0b000>
        __m128i convert_packed_half_float() { return _mm256_cvtps_ph(data, _imm3); }
        template <uint8_t _imm3 = 0b000>
        __m128i convert_packed_f16() { return _mm256_cvtps_ph(data, _imm3); }

        static f32x8_p convert_from_packed_f16(__m128i a) { return f32x8_p(_mm256_cvtph_ps(a)); }

        static f32x8_p load_8floats(const void *arr) { return f32x8_p(_mm256_loadu_ps((float *)arr)); }
        static f32x8_p load_1float_broadcast(float *arr) { return f32x8_p(_mm256_broadcast_ss(arr)); }
        static f32x8_p load_4floats_broadcast(const __m128 *arr) { return f32x8_p(_mm256_broadcast_ps(arr)); }
        static f32x8_p load_mask(float const *arr, __m256i mask) { return f32x8_p(_mm256_maskload_ps(arr, mask)); }

        // {a[2], b[2], a[3], b[3], a[6], b[6], a[7], b[7]}
        static f32x8_p unpack_high(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpackhi_ps(a.data, b.data)); }
        // {a[0], b[0], a[1], b[1], a[4], b[4], a[5], b[5]}
        static f32x8_p unpack_low(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_unpacklo_ps(a.data, b.data)); }

        void store(float *a) { _mm256_storeu_ps(a, data); }
        void store(f32x8_p *a) { _mm256_store_ps((float *)a, data); }
        void load(float *a) { data = _mm256_loadu_ps(a); }
        void load(f32x8_p *a) { data = _mm256_load_ps((float *)a); }

        float operator[](uint8_t idx) const { return data.m256_f32[idx]; }
        float &operator[](uint8_t idx) { return data.m256_f32[idx]; }

        f32x8_p copy() const { return f32x8_p(data); }

        f32x8_p operator+(f32x8_p a) { return f32x8_p(_mm256_add_ps(data, a.data)); }
        f32x8_p operator-(f32x8_p a) { return f32x8_p(_mm256_sub_ps(data, a.data)); }
        f32x8_p operator*(f32x8_p a) { return f32x8_p(_mm256_mul_ps(data, a.data)); }
        f32x8_p operator/(f32x8_p a) { return f32x8_p(_mm256_div_ps(data, a.data)); }
        void operator+=(f32x8_p a) { data = _mm256_add_ps(data, a.data); }
        void operator-=(f32x8_p a) { data = _mm256_sub_ps(data, a.data); }
        void operator*=(f32x8_p a) { data = _mm256_mul_ps(data, a.data); }
        void operator/=(f32x8_p a) { data = _mm256_div_ps(data, a.data); }

        f32x8_p operator&(f32x8_p a) { return f32x8_p(_mm256_and_ps(data, a.data)); }
        void operator&=(f32x8_p a) { data = _mm256_and_ps(data, a.data); }
        f32x8_p operator|(f32x8_p a) { return f32x8_p(_mm256_or_ps(data, a.data)); }
        void operator|=(f32x8_p a) { data = _mm256_or_ps(data, a.data); }
        f32x8_p operator^(f32x8_p a) { return f32x8_p(_mm256_xor_ps(data, a.data)); }
        void operator^=(f32x8_p a) { data = _mm256_xor_ps(data, a.data); }
        static f32x8_p andnot(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_andnot_ps(a.data, b.data)); }
        void do_andnot(f32x8_p a) { data = _mm256_andnot_ps(data, a.data); }

        // xor for equal makes all bits 0
        bool operator==(f32x8_p a)
        {
            auto int_result = _mm256_castps_si256(_mm256_xor_ps(data, a.data));
            // check if all bits are 0
            return _mm256_testz_si256(int_result, int_result) != 0;
        }

        // for j = 0 to 7, ret[j] = imm8[j]?a[j]:b[j]
        template <uint8_t _imm8>
        void blend(f32x8_p a) { data = _mm256_blend_ps(data, a.data, _imm8); }
        template <uint8_t _imm8>
        static f32x8_p blend(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_blend_ps(a.data, b.data, _imm8)); }

        // for j = 0 to 7, ret[j] = mask[j].signbit?a[j]:b[j]
        void blend(f32x8_p a, f32x8_p mask) { data = _mm256_blendv_ps(data, a.data, mask.data); }
        static f32x8_p blend(f32x8_p a, f32x8_p b, f32x8_p mask) { return f32x8_p(_mm256_blendv_ps(a.data, b.data, mask.data)); }

        /// advanced calculation


        // {a[0]-b[0], a[1]+b[1], a[2]-b[2], a[3]+b[3]...}
        void addsub(f32x8_p a) { data = _mm256_addsub_ps(data, a.data); }
        static f32x8_p addsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_addsub_ps(a.data, b.data)); }
        // {a[0]+a[1], a[2]+a[3], b[0]+b[1], b[2]+b[3]...}
        void hadd(f32x8_p a) { data = _mm256_hadd_ps(data, a.data); }
        static f32x8_p hadd(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hadd_ps(a.data, b.data)); }
        // {a[0]-a[1], a[2]-a[3], b[0]-b[1], b[2]-b[3]...}
        void hsub(f32x8_p a) { data = _mm256_hsub_ps(data, a.data); }
        static f32x8_p hsub(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_hsub_ps(a.data, b.data)); }
        // {a[0]*mul[0]+add[0], a[1]*mul[1]+add[1]...}
        void mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }
        void fmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fmadd_ps(data, mul.data, add.data); }
        static f32x8_p mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fmadd_ps(a.data, mul.data, add.data)); }
        // {a[0]*mul[0]-add[0], a[1]*mul[1]-add[1]...}
        void mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, sub.data); }
        void fmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fmsub_ps(data, mul.data, sub.data); }
        static f32x8_p mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fmsub_ps(a.data, mul.data, sub.data)); }
        // {-a[0]*mul[0]+add[0], -a[1]*mul[1]+add[1]...}
        void neg_mul_then_add(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }
        void fnmadd(f32x8_p mul, f32x8_p add) { data = _mm256_fnmadd_ps(data, mul.data, add.data); }
        static f32x8_p neg_mul_then_add(f32x8_p a, f32x8_p mul, f32x8_p add) { return f32x8_p(_mm256_fnmadd_ps(a.data, mul.data, add.data)); }
        // {-a[0]*mul[0]-add[0], -a[1]*mul[1]-add[1]...}
        void neg_mul_then_sub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }
        void fnmsub(f32x8_p mul, f32x8_p sub) { data = _mm256_fnmsub_ps(data, mul.data, sub.data); }
        static f32x8_p neg_mul_then_sub(f32x8_p a, f32x8_p mul, f32x8_p sub) { return f32x8_p(_mm256_fnmsub_ps(a.data, mul.data, sub.data)); }
        // {a[0] * mul[0] + addsub[0], a[1] * mul[1] - addsub[1]...}
        void mul_addsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }
        void fmaddsub(f32x8_p mul, f32x8_p addsub) { data = _mm256_fmaddsub_ps(data, mul.data, addsub.data); }
        static f32x8_p mul_addsub(f32x8_p a, f32x8_p mul, f32x8_p addsub) { return f32x8_p(_mm256_fmaddsub_ps(a.data, mul.data, addsub.data)); }
        // {a[0] * mul[0] - subadd[0], a[1] * mul[1] + subadd[1]...}
        void mul_subadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }
        void fmsubadd(f32x8_p mul, f32x8_p subadd) { data = _mm256_fmsubadd_ps(data, mul.data, subadd.data); }
        static f32x8_p mul_subadd(f32x8_p a, f32x8_p mul, f32x8_p subadd) { return f32x8_p(_mm256_fmsubadd_ps(a.data, mul.data, subadd.data)); }




        // {a[0]>b[0]?a[0]:b[0], a[1]>b[1]?a[1]:b[1]...}
        void maximum(f32x8_p a) { data = _mm256_max_ps(data, a.data); }
        static f32x8_p maximum(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_max_ps(a.data, b.data)); }

        // {a[0]<b[0]?a[0]:b[0], a[1]<b[1]?a[1]:b[1]...}
        void minimum(f32x8_p a) { data = _mm256_min_ps(data, a.data); }
        static f32x8_p minimum(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_min_ps(a.data, b.data)); }

        // see f32x4_p::shuffle. do the same shuffle for both high and low 128bit
        // {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}
        template <uint8_t _imm8>
        void shuffle(f32x8_p a) { data = _mm256_shuffle_ps(data, a.data, _imm8); }
        template <uint8_t _imm8>
        void remix(f32x8_p a) { data = _mm256_shuffle_ps(data, a.data, _imm8); }
        // {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}
        template <uint8_t _imm8>
        f32x8_p remixed(f32x8_p a) const { return f32x8_p(_mm256_shuffle_ps(data, a.data, _imm8)); }
        // {a[imm[1:0]], a[imm[3:2]], b[imm[5:4]], b[imm[7:6]]}
        template <uint8_t _imm8>
        f32x8_p remixed(f32x8_p a) { return f32x8_p(_mm256_shuffle_ps(data, a.data, _imm8)); }
        template <uint8_t _imm8>
        static f32x8_p shuffle(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_shuffle_ps(a.data, b.data, _imm8)); }
        // see f32x4_p::permute. do the same permute for both high and low 128bit
        void permute(f32x8_p a, __m256i b) { data = _mm256_permutevar_ps(a.data, b); }
        static f32x8_p permute(f32x8_p a, f32x8_p b, __m256i c) { return f32x8_p(_mm256_permutevar_ps(a.data, c)); }

        // {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}
        template <uint8_t _imm8>
        void permute() { data = _mm256_permute_ps(data, _imm8); }
        template <uint8_t _imm8>
        void reorder() { data = _mm256_permute_ps(data, _imm8); }
        // {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}
        template <uint8_t _imm8>
        f32x8_p reordered() const { return f32x8_p(_mm256_permute_ps(data, _imm8)); }
        // {a[imm[1:0]], a[imm[3:2]], a[imm[5:4]], a[imm[7:6]]}
        template <uint8_t _imm8>
        f32x8_p reordered() { return f32x8_p(_mm256_permute_ps(data, _imm8)); }
        template <uint8_t _imm8>
        f32x8_p permuted() { return f32x8_p(_mm256_permute_ps(data, _imm8)); }
        template <uint8_t _imm8>
        static f32x8_p permute(f32x8_p a) { return f32x8_p(_mm256_permute_ps(a.data, _imm8)); }

        // ret.low  = imm8[2]? 0 : switch imm8[1:0] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}
        // ret.high = imm8[6]? 0 : switch imm8[5:4] {case 0: a.low; case 1: a.high; case 2: b.low; case 3: b.high}
        template <uint8_t _imm8>
        void permute2f128(f32x8_p a) { data = _mm256_permute2f128_ps(data, a.data, _imm8); }
        template <uint8_t _imm8>
        void shuffle128(f32x8_p a) { data = permute2f128<_imm8>(data, a.data); }
        template <uint8_t _imm8>
        static f32x8_p permute2f128(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_permute2f128_ps(a.data, b.data, _imm8)); }
        template <uint8_t _imm8>
        static f32x8_p shuffle128(f32x8_p a, f32x8_p b) { return permute2f128<_imm8>(a.data, b.data); }

        enum compare_operation_imm8 : uint8_t
        {
            OP_CMP_EQ_OQ = 0,
            OP_CMP_LT_OS = 1,
            OP_CMP_LE_OS = 2,
            OP_CMP_UNORD_Q = 3,
            OP_CMP_NEQ_UQ = 4,
            OP_CMP_NLT_US = 5,
            OP_CMP_NLE_US = 6,
            OP_CMP_ORD_Q = 7,
            OP_CMP_EQ_UQ = 8,
            OP_CMP_NGE_US = 9,
            OP_CMP_NGT_US = 10,
            OP_CMP_FALSE_OQ = 11,
            OP_CMP_NEQ_OQ = 12,
            OP_CMP_GE_OS = 13,
            OP_CMP_GT_OS = 14,
            OP_CMP_TRUE_UQ = 15,
            OP_CMP_EQ_OS = 16,
            OP_CMP_LT_OQ = 17,
            OP_CMP_LE_OQ = 18,
            OP_CMP_UNORD_S = 19,
            OP_CMP_NEQ_US = 20,
            OP_CMP_NLT_UQ = 21,
            OP_CMP_NLE_UQ = 22,
            OP_CMP_ORD_S = 23,
            OP_CMP_EQ_US = 24,
            OP_CMP_NGE_UQ = 25,
            OP_CMP_NGT_UQ = 26,
            OP_CMP_FALSE_OS = 27,
            OP_CMP_NEQ_OS = 28,
            OP_CMP_GE_OQ = 29,
            OP_CMP_GT_OQ = 30,
            OP_CMP_TRUE_US = 31
        };

        template <compare_operation_imm8 _imm8>
        static f32x8_p compare(f32x8_p a, f32x8_p b) { return f32x8_p(_mm256_cmp_ps(a.data, b.data, _imm8)); }

        __m256i convert_int32p(f32x8_p a) { return _mm256_cvtps_epi32(a.data); }
        __m256i convert_int32p_trunc(f32x8_p a) { return _mm256_cvttps_epi32(a.data); }

        float get_float() const { return _mm256_cvtss_f32(data); }

        f32x4_p get_low() const { return f32x4_p(_mm256_extractf128_ps(data, 0)); }
        f32x4_p get_high() const { return f32x4_p(_mm256_extractf128_ps(data, 1)); }

        static void zeroall() { _mm256_zeroall(); }
        static void zeroupper() { _mm256_zeroupper(); }

        void load_broadcast(float *mem_addr) { data = _mm256_broadcast_ss(mem_addr); }

        void move_high_dup() { data = _mm256_movehdup_ps(data); }
        void move_odd2even() { data = _mm256_movehdup_ps(data); }
        f32x8_p copy_odd2even() { return f32x8_p(_mm256_movehdup_ps(data)); }
        void move_low_dup() { data = _mm256_moveldup_ps(data); }
        void move_even2odd() { data = _mm256_moveldup_ps(data); }
        f32x8_p copy_even2odd() { return f32x8_p(_mm256_moveldup_ps(data)); }

        void rcp() { data = _mm256_rcp_ps(data); }
        void rsqrt() { data = _mm256_rsqrt_ps(data); }
        void sqrt() { data = _mm256_sqrt_ps(data); }

        template <uint8_t _rounding_imm>
        void round() { data = _mm256_round_ps(data, _rounding_imm); }

        // zf: if the (a & b)sign are all zero, return 1, else return 0
        // means no sign bit pair is both 1
        int test_zf() { return _mm256_testz_ps(data, data); } // means all 0
        int test_zf(f32x8_p a) { return _mm256_testz_ps(data, a.data); }
        int test_if_not_both_bit1(f32x8_p a) { return _mm256_testz_ps(data, a.data); }
        // cf: if the (~a & b)sign are all zero, return 1, else return 0
        // means no sign bit pair is a 0 and b 1 (each b1 in after a1)
        int test_cf() { return _mm256_testc_ps(data, data); } // always 1
        int test_cf(f32x8_p a) { return _mm256_testc_ps(data, a.data); }
        int test_if_bit_contained(f32x8_p a) { return _mm256_testc_ps(data, a.data); }

        // (!cf) && (!zf) , only when both cf and zf are 0, return 1
        // means some sign bit pair is both 1 and some sign bit pair is a 0 and b 1
        // also means there are: >=1x a1b1 and >=1x a0b1
        int test_ncz() { return _mm256_testnzc_ps(data, data); } // always 0
        int test_ncz(f32x8_p a) { return _mm256_testnzc_ps(data, a.data); }

        int get_mask() const { return _mm256_movemask_ps(data); }
        int get_signs() const { return _mm256_movemask_ps(data); }
    };

        // 4 complex float32x2 :{c[0].re,c[0].im,c[1].re,c[1].im,c[2].re,c[2].im,c[3].re,c[3].im}
        struct fc32x4_p : public f32x8_p
        {
        public:
            fc32x4_p() : f32x8_p() {}
            fc32x4_p(__m256 data) : f32x8_p(data) {}
            fc32x4_p(f32x8_p a) : f32x8_p(a) {}
            fc32x4_p(float a, float b, float c, float d, float e = 0.0f, float f = 0.0f, float g = 0.0f, float h = 0.0f) : f32x8_p(a, b, c, d, e, f, g, h) {}
            fc32x4_p(float a) : f32x8_p(a) {}
            fc32x4_p(float *a) : f32x8_p(a) {}

            // operator +, - is the same as f32x8_p

            // complex mul complex, use hadd and hsub
            static fc32x4_p multiply_complex_v0(f32x8_p a_bI, f32x8_p c_dI)
            { // real: a*c - b*d, imag: a*d + b*c
                f32x8_p ac_bdI = a_bI * c_dI;
                f32x8_p ad_bcI = a_bI * c_dI.reordered<0b10'11'00'01>();
                ac_bdI.hsub(ac_bdI);
                ad_bcI.hadd(ad_bcI);
                return fc32x4_p(ac_bdI
                                    .remixed<0b11'10'01'00>(ad_bcI) // {r0,r1,i0,i1}
                                    .reordered<0b11'01'10'00>());   // {r0,i0,r1,i1}
            }
            fc32x4_p cmul_v0(fc32x4_p a) { return multiply_complex_v0(data, a.data); }

            // complex mul complex, use addsub, cost about 90% time of _v0
            static fc32x4_p multiply_complex_v1(f32x8_p a_bI, f32x8_p c_dI)
            { // real: a*c - b*d, imag: a*d + b*c
                f32x8_p a_aI = a_bI.copy_even2odd();
                f32x8_p b_bI = a_bI.copy_odd2even();
                f32x8_p ac_adI = a_aI * c_dI;
                f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();
                return fc32x4_p(fc32x4_p::addsub(ac_adI, bd_bcI));
            }
            fc32x4_p cmul_v1(fc32x4_p a) { return multiply_complex_v1(data, a.data); }

            // complex mul complex, use mul_addsub, cost about 98% time of _v1
            static fc32x4_p multiply_complex_v2(f32x8_p a_bI, f32x8_p c_dI)
            { // real: a*c - b*d, imag: a*d + b*c
                f32x8_p a_aI = a_bI.copy_even2odd();
                f32x8_p b_bI = a_bI.copy_odd2even();
                //f32x8_p ac_adI = a_aI * c_dI;
                f32x8_p bd_bcI = b_bI * c_dI.reordered<0b10'11'00'01>();
                return fc32x4_p::mul_addsub(a_aI, c_dI, bd_bcI);
            }
            fc32x4_p cmul_v2(fc32x4_p a) { return multiply_complex_v2(data, a.data); }
            // complex mul complex, use mul_addsub, 
            static fc32x4_p multiply_complex_v3(f32x8_p a_bI_, f32x8_p c_dI)
            { // real: a*c - b*d, imag: a*d + b*c
                f32x8_p a_bI = a_bI_.data;
                //f32x8_p a_aI = a_bI.copy_even2odd();
                //f32x8_p b_bI = a_bI.copy_odd2even();
                //f32x8_p ac_adI = a_aI * c_dI;
                //f32x8_p bd_bcI = a_bI.copy_odd2even() * c_dI.reordered<0b10'11'00'01>();
                return fc32x4_p::mul_addsub(a_bI.copy_even2odd(), c_dI, a_bI.copy_odd2even() * c_dI.reordered<0b10'11'00'01>());
            }
            fc32x4_p cmul_v3(fc32x4_p a) { return multiply_complex_v2(data, a.data); }
        };
        EXTERN_C
        {
        _declspec(dllexport) auto __vectorcall multiply_complex_v0(__m256 a, __m256 b)
        {
            return fc32x4_p::multiply_complex_v0(a, b).data;
        }
        _declspec(dllexport) auto __vectorcall multiply_complex_v1(__m256 a, __m256 b)
        {
            return fc32x4_p::multiply_complex_v1(a, b).data;
        }

        _declspec(dllexport) auto __vectorcall multiply_complex_v2(__m256 a, __m256 b)
        {
            return fc32x4_p::multiply_complex_v2(a, b).data;
        }
        _declspec(dllexport) auto __vectorcall multiply_complex_v3(__m256 a, __m256 b)
        {
            return fc32x4_p::multiply_complex_v3(a, b).data;
    };
    };

#ifdef namespace_bionukg
}
#endif

#endif
  • 34
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值