本文参考 GitHub 上的开源实现。
上文是二维 CNN 的 C 语言实现。我参考上文,照葫芦画瓢实现了一维 CNN 的前向传播。因嵌入式需求,去掉了原文中的 malloc 函数,取而代之的是申请 input、output 两个静态数组,在 input 和 output 之间来回挪腾数据完成模型的前向运算,避免了动态内存的频繁申请和释放。静态数组的大小必须能容下前向运算过程中最大 shape 的输入或输出。
// Convolution module (1-D).
// weight layout: [out_channels][in_channels][kernel_size], row-major.
// Ping-pong buffer convention used throughout this file: the layer result
// (out_size*out_channels doubles) is finally copied back into `input` and
// `out` is cleared, so the caller reads the output from `input`.
// Both static buffers must be large enough for the largest intermediate
// shape (here pad_size * max(in_channels, out_channels)).
// activation == 1 applies ReLU; any other value applies no activation.
// Always returns 1.
int conv(const double* weight, double* input, double* out, int input_size, int in_channels, int out_channels,
int kernel_size, int stride, int pad, int activation)
{
    int out_size = (input_size + 2 * pad - kernel_size) / stride + 1;
    int pad_size = input_size + 2 * pad;
    int index, channel;
    // Stage the zero-padded input into `out` unconditionally. The original
    // pad == 0 branch did `out = input;`, aliasing the two buffers, and the
    // later memset then wiped the input data before the convolution ran.
    for (channel = 0; channel < in_channels; channel++) {
        for (index = 0; index < pad_size; index++) {
            int src = index - pad;
            out[channel*pad_size + index] =
                (src >= 0 && src < input_size) ? input[channel*input_size + src] : 0.0;
        }
    }
    // Move the padded data back into `input`. The size must cover the
    // *input* channels: the original used out_channels here and copied the
    // wrong amount whenever in_channels != out_channels. The memset clears
    // everything staged above plus the full accumulation region.
    memcpy(input, out, pad_size*in_channels * sizeof(double));
    memset(out, 0, (size_t)pad_size *
           (in_channels > out_channels ? in_channels : out_channels) * sizeof(double));
    int in_channel, out_channel, kernel_index, out_index;
    for (out_channel = 0; out_channel < out_channels; out_channel++) {
        for (out_index = 0; out_index < out_size; out_index++) {
            for (in_channel = 0; in_channel < in_channels; in_channel++) {
                for (kernel_index = 0; kernel_index < kernel_size; kernel_index++) {
                    // The window start must advance by `stride`; the
                    // original ignored stride and always stepped by 1.
                    out[out_channel*out_size + out_index] +=
                        weight[out_channel*in_channels*kernel_size + in_channel*kernel_size + kernel_index]
                        * input[in_channel*pad_size + out_index*stride + kernel_index];
                }
            }
            if (activation == 1) {
                // ReLU (swap in another activation here if needed)
                if (out[out_channel*out_size + out_index] < 0) {
                    out[out_channel*out_size + out_index] = 0;
                }
            }
        }
    }
    // Ping-pong: hand the result back in `input`, clear `out`. Only
    // out_size*out_channels values are valid layer output.
    memcpy(input, out, out_size*out_channels * sizeof(double));
    memset(out, 0, out_size*out_channels * sizeof(double));
    return 1;
}
// Padding helper: read element `index` of channel `channel` as if the
// channel were zero-padded with `pad` zeros on each side. Positions that
// fall inside the padding (outside [0, input_size)) read as 0.
double get_point(double* input, int input_size, int channel, int pad, int index) {
    int src = index - pad;
    if (src >= 0 && src < input_size) {
        return input[channel * input_size + src];
    }
    return 0;
}
// Pooling module: 1-D max pooling.
// Ping-pong convention: the result (out_size*out_channels doubles) is
// copied back into `input` and `out` is cleared. Always returns 1.
int maxpool(double* input, double* out, int in_size, int in_channels, int pool_size, int stride) {
    int out_size, out_channels;
    out_size = (in_size - pool_size) / stride + 1;
    out_channels = in_channels;
    int index, out_channel, pool_index, out_index;
    for (out_channel = 0; out_channel < out_channels; out_channel++) {
        for (out_index = 0; out_index < out_size; out_index++) {
            // Seed the running maximum with the first element of the
            // window. The original seeded with 0.0, which yields wrong
            // results for windows whose values are all negative.
            double max = input[out_channel*in_size + stride*out_index];
            for (pool_index = 1; pool_index < pool_size; pool_index++) {
                index = stride * out_index + pool_index;
                if (input[out_channel*in_size + index] > max) {
                    max = input[out_channel*in_size + index];
                }
            }
            out[out_channel*out_size + out_index] = max;
        }
    }
    // Ping-pong: result back into input, out cleared.
    memcpy(input, out, out_size*out_channels * sizeof(double));
    memset(out, 0, out_size*out_channels * sizeof(double));
    return 1;
}
// Average pooling module (1-D). With in_size == pool_size and stride == 1
// this degenerates to global average pooling.
// Ping-pong convention: the result (out_size * in_channels doubles) is
// copied back into `input` and `out` is cleared. Always returns 1.
int average_pool(double* input, double* out, int in_size, int in_channels, int pool_size, int stride) {
    int out_size = (in_size - pool_size) / stride + 1;
    int ch, o, k;
    for (ch = 0; ch < in_channels; ch++) {
        const double* row = input + ch * in_size;
        for (o = 0; o < out_size; o++) {
            // Sum the window, then divide by its (fixed) length.
            double acc = 0.0;
            for (k = 0; k < pool_size; k++) {
                acc += row[stride * o + k];
            }
            out[ch * out_size + o] = acc / pool_size;
        }
    }
    // Ping-pong: result back into input, out cleared.
    memcpy(input, out, out_size * in_channels * sizeof(double));
    memset(out, 0, out_size * in_channels * sizeof(double));
    return 1;
}
// Fully-connected module: out[j] = dot(input, weight[j]).
// weight layout: [out_size][in_size], row-major.
// activation == 1 applies ReLU, matching the convention in conv(). The
// original left `activation` unused (the ReLU line was commented out),
// making the parameter dead and the layer inconsistent with conv().
// Ping-pong convention: the result (out_size doubles) is copied back into
// `input` and `out` is cleared. Always returns 1.
int full_connection(const double* weight, double* input, double* out, int in_size, int out_size, int activation) {
    int index, out_index;
    for (out_index = 0; out_index < out_size; out_index++) {
        out[out_index] = 0.0;
        for (index = 0; index < in_size; index++) {
            out[out_index] += input[index] * weight[out_index*in_size + index];
        }
        if (activation == 1) {
            // ReLU
            out[out_index] = out[out_index] > 0 ? out[out_index] : 0;
        }
    }
    // Ping-pong: result back into input, out cleared.
    memcpy(input, out, out_size * sizeof(double));
    memset(out, 0, out_size * sizeof(double));
    return 1;
}