armv8, dup
code
#include<stdio.h>
#include<stdlib.h>
#include<time.h>
#include<sys/time.h>
#include<arm_neon.h>
#include<math.h>
/*
 * Returns the current wall-clock time in milliseconds as a double,
 * built from the seconds and microseconds reported by gettimeofday().
 * Intended for measuring short elapsed intervals by subtraction.
 */
double get_current_time()
{
    struct timeval now;
    gettimeofday(&now, NULL);

    /* Convert both fields to milliseconds before summing. */
    const double ms_from_sec = now.tv_sec * 1000.0;
    const double ms_from_usec = now.tv_usec / 1000.0;
    return ms_from_sec + ms_from_usec;
}
/*
 * Scalar reference implementation: out[i] = w * src[i] + b for i in [0, count).
 *
 * src   - input array, at least count floats
 * out   - output array, at least count floats
 * count - number of elements; values <= 0 write nothing
 * u,std - mean / deviation of the full form w * ((src[i] - u) / std) + b;
 *         that normalisation step is currently disabled, so both are unused
 * w, b  - scale and bias applied to every element
 */
void batch_c(float* src, float* out, int count, float u, float std, float w, float b)
{
    (void)u;
    (void)std;

    /* A negative count would otherwise be converted to a huge size_t bound
     * and run far past the end of both buffers. */
    if (count <= 0)
        return;

    size_t n = (size_t)count;
    for (size_t i = 0; i < n; i++)
    {
        out[i] = w * src[i] + b;
    }
}
/*
 * AArch64 NEON version of batch_c: out[i] = w * src[i] + b for i in [0, count).
 *
 * The largest multiple-of-4 prefix is processed four floats at a time by the
 * inline assembly; any remaining 1-3 elements are finished by a scalar loop.
 * On targets other than AArch64 the whole range goes through the scalar loop.
 *
 * src   - input array, at least count floats
 * out   - output array, at least count floats
 * count - number of elements; values <= 0 write nothing
 * u,std - unused (the normalisation step is not implemented here)
 * w, b  - scale and bias applied to every element
 */
void batch_assembly(float* src, float* out, int count, float u, float std, float w, float b)
{
    (void)u;
    (void)std;

    /* The assembly loop always executes at least once, so it must never be
     * entered with nothing to do. */
    if (count <= 0)
        return;

    int tail = count;   /* elements left for the scalar loop */

#if defined(__aarch64__)
    int vec_count = count & ~3;   /* multiple of 4 handled by the vector loop */
    tail = count - vec_count;

    if (vec_count > 0)
    {
        /* out, src and the counter are all modified by the loop, so they are
         * read-write ("+r") operands; after the statement out/src point just
         * past the vectorised prefix. The counter is an int, hence the %w
         * modifier to operate on the 32-bit view of its register. */
        __asm__ volatile(
            "dup v2.4s, %w[w]                \n" // broadcast w to all 4 lanes
            "dup v3.4s, %w[b]                \n" // broadcast b to all 4 lanes
            "1:                              \n"
            "prfm pldl1keep, [%[src], #128]  \n" // prefetch hint for upcoming input
            "ld1 {v0.4s}, [%[src]], #16      \n" // load 4 floats, advance src
            "orr v1.16b, v3.16b, v3.16b      \n" // v1 = b (register copy)
            "fmla v1.4s, v0.4s, v2.4s        \n" // v1 += src * w
            "subs %w[n], %w[n], #4           \n" // 4 fewer elements remaining
            "st1 {v1.4s}, [%[out]], #16      \n" // store 4 floats, advance out
            "b.gt 1b                         \n"
            : [out] "+r"(out),
              [src] "+r"(src),
              [n] "+r"(vec_count)
            : [w] "r"(w),
              [b] "r"(b)
            : "cc", "memory", "v0", "v1", "v2", "v3"
        );
    }
#endif

    /* Scalar remainder (or the whole range on non-AArch64 builds). */
    for (int i = 0; i < tail; i++)
    {
        out[i] = w * src[i] + b;
    }
}
/* Prints n floats separated by spaces, followed by a newline. */
static void print_floats(const float *values, int n)
{
    for (int i = 0; i < n; i++)
        printf("%f ", values[i]);
    printf("\n");
}

/*
 * Demo driver: fills an input buffer with pseudo-random values, runs the
 * scalar and the assembly implementation over it, and prints the average
 * time per call together with the input and output of each run so the two
 * results can be compared by eye.
 *
 * Returns 0 on success, EXIT_FAILURE if a buffer cannot be allocated.
 */
int main(void){
    int num_ = 12;   // number of elements
    int loop = 2;    // timed iterations per implementation
    double start, end, cur;
    float u = 0.1;   // mean
    float std = 0.2; // variance
    float w = 0.3;   // weight
    float b = 0.4;   // bias

    float *src_a = malloc(sizeof *src_a * (size_t)num_);
    float *src_b = malloc(sizeof *src_b * (size_t)num_);
    if (src_a == NULL || src_b == NULL)
    {
        fprintf(stderr, "failed to allocate buffers\n");
        free(src_a);
        free(src_b);
        return EXIT_FAILURE;
    }

    for (int i = 0; i < num_; i++)
    {
        src_a[i] = (rand() / (RAND_MAX + 1.0)) * 2 - 1; // in [-1, 1)
    }

    // warm up
    for (int i = 0; i < 10; i++)
        batch_c(src_a, src_b, num_, u, std, w, b);

    // test for c
    start = get_current_time();
    for (int i = 0; i < loop; i++)
        batch_c(src_a, src_b, num_, u, std, w, b);
    end = get_current_time();
    cur = (end - start) / loop;
    printf("c test:%f | time:%f ms \n", 0., cur);

    // debug show
    print_floats(src_a, num_);
    print_floats(src_b, num_);
    printf("u:%f, std:%f, w:%f, b:%f \n", u, std, w, b);

    // test for neon assembly
    start = get_current_time();
    for (int i = 0; i < loop; i++)
        batch_assembly(src_a, src_b, num_, u, std, w, b);
    end = get_current_time();
    cur = (end - start) / loop;
    printf("assembly:%f | time:%f ms \n", 2., cur);

    // debug show
    print_floats(src_a, num_);
    print_floats(src_b, num_);
    printf("u:%f, std:%f, w:%f, b:%f \n", u, std, w, b);

    free(src_a);
    free(src_b);
    return 0;
}
输出
c test:0.000000 | time:0.000488 ms
0.680375 -0.211234 0.566198 0.596880 0.823295 -0.604897 -0.329554 0.536459 -0.444451 0.107940 -0.045206 0.257742
0.604113 0.336630 0.569860 0.579064 0.646988 0.218531 0.301134 0.560938 0.266665 0.432382 0.386438 0.477323
u:0.100000, std:0.200000, w:0.300000, b:0.400000
assembly:2.000000 | time:0.000488 ms
0.680375 -0.211234 0.566198 0.596880 0.823295 -0.604897 -0.329554 0.536459 -0.444451 0.107940 -0.045206 0.257742
0.604113 0.336630 0.569860 0.579064 0.646988 0.218531 0.301134 0.560938 0.266665 0.432382 0.386438 0.477323
u:0.100000, std:0.200000, w:0.300000, b:0.400000
主要注意 dup 指令的使用：`dup Vd.4s, Wn` 把通用寄存器 Wn 中的 32 位标量复制（广播）到向量寄存器的 4 个 float 通道，这里用它把 w 和 b 分别广播到 v2 和 v3。
参考