/* A simple C implementation of LeNet. */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
/* Presumably the input image width and height in pixels (28 x 28) —
   TODO confirm against the code that allocates the image buffers. */
#define WIDTH 28
#define HEIGHT 28
/* Presumably the number of output classes of the network. */
#define NUM_CLASSES 10
/* Presumably the expected training / test sample counts (these match the
   standard MNIST split sizes) — confirm against the data files used. */
#define NUM_TRAIN 60000
#define NUM_TEST 10000
/*
 * State of one convolutional layer: integer geometry parameters followed by
 * pointers to its filter, bias and output buffers. Field order is part of the
 * layout and must not be changed.
 */
typedef struct conv_layer conv_layer;

struct conv_layer {
    int stride, num_filters, filter_size;
    int input_size, output_size;    /* presumably square side lengths — TODO confirm */
    double ***filters;
    double **biases;
    double **output;
};
/*
 * State of one pooling layer: integer geometry parameters followed by a
 * pointer to its output buffer. Field order is part of the layout and must
 * not be changed.
 */
typedef struct pool_layer pool_layer;

struct pool_layer {
    int stride, pool_size;
    int input_size, output_size;    /* presumably square side lengths — TODO confirm */
    double **output;
};
/*
 * State of one fully connected layer: its neuron count followed by pointers
 * to the weight, bias and output buffers. Field order is part of the layout
 * and must not be changed.
 */
typedef struct dense_layer dense_layer;

struct dense_layer {
    int num_neurons;
    double *weights, *biases, *output;
};
/*
 * Reverse the order of the four bytes of a 32-bit value.
 *
 * Callers in this file use it to decode the big-endian header fields of IDX
 * files after reading them into a native int, which is only correct on a
 * little-endian host with a 32-bit int.
 *
 * The swap is performed in unsigned arithmetic because shifting a byte
 * >= 0x80 into bit 31 of a signed int (`(int) b << 24`) is undefined
 * behaviour.
 *
 * @param i  value whose bytes are reversed
 * @return   i with byte order reversed
 */
int reverse_int(int i) {
    unsigned int u = (unsigned int) i;
    unsigned int swapped = ((u & 0x000000FFu) << 24)
                         | ((u & 0x0000FF00u) << 8)
                         | ((u & 0x00FF0000u) >> 8)
                         | ((u & 0xFF000000u) >> 24);
    /* Converting a value above INT_MAX back to int is implementation-defined
       (two's-complement wrap on mainstream compilers), not undefined. */
    return (int) swapped;
}
/*
 * Read one big-endian 32-bit unsigned header field from an IDX image file.
 * Exits with status 1 if the file ends before four bytes are read.
 * Decoding byte by byte keeps this independent of host endianness.
 */
static unsigned long idx_image_read_be32(FILE *file) {
    unsigned long value = 0;
    for (int b = 0; b < 4; b++) {
        int c = getc(file);
        if (c == EOF) {
            printf("Error: Truncated image file\n");
            exit(1);
        }
        value = (value << 8) | (unsigned long) c;
    }
    return value;
}

/*
 * Load an IDX image file (magic number 2051) into caller-provided buffers.
 *
 * Pixel bytes are stored as doubles in row-major order:
 * images[i][row * num_cols + col] for each of the num_images images.
 *
 * @param images      array of num_images buffers; each must hold at least
 *                    num_rows * num_cols doubles (presumably allocated as
 *                    WIDTH * HEIGHT by the caller — TODO confirm)
 * @param filename    path of the IDX image file
 * @param num_images  number of images the file is required to contain
 *
 * On any error (file cannot be opened, wrong magic number, image count
 * mismatch, oversized dimensions, truncated data) a message is printed and
 * the process exits with status 1.
 */
void read_images(double **images, char *filename, int num_images) {
    FILE *file = fopen(filename, "rb");
    if (file == NULL) {
        printf("Error: Cannot open image file %s\n", filename);
        exit(1);
    }

    unsigned long magic_number = idx_image_read_be32(file);
    if (magic_number != 2051) {
        printf("Error: Invalid image file format\n");
        exit(1);
    }

    unsigned long num_images_read = idx_image_read_be32(file);
    if (num_images < 0 || num_images_read != (unsigned long) num_images) {
        printf("Error: Invalid number of images\n");
        exit(1);
    }

    unsigned long num_rows = idx_image_read_be32(file);
    unsigned long num_cols = idx_image_read_be32(file);
#if defined(WIDTH) && defined(HEIGHT)
    /* The dimensions come from the (untrusted) file; when the image geometry
       macros are available, refuse images that would overrun a
       WIDTH * HEIGHT buffer. */
    if (num_rows > HEIGHT || num_cols > WIDTH) {
        printf("Error: Invalid image dimensions\n");
        exit(1);
    }
#endif

    for (int i = 0; i < num_images; i++) {
        for (unsigned long j = 0; j < num_rows; j++) {
            for (unsigned long k = 0; k < num_cols; k++) {
                int pixel = getc(file);
                if (pixel == EOF) {
                    printf("Error: Truncated image file\n");
                    exit(1);
                }
                images[i][j * num_cols + k] = (double) pixel;
            }
        }
    }
    fclose(file);
}
void read_labels(double *labels, char *filename, int num_labels) {
FILE *file = fopen(filename, "rb");
int magic_number = 0;
int num_labels_read = 0;
fread(&magic_number, sizeof(magic_number), 1, file);
magic_number = reverse_int(magic_number);
if (magic_number != 2049) {
printf("Error: Invalid label file format\n");
exit(1);
}
fread(&num_labels_read, sizeof(num_labels_read), 1, file);
num_labels_read = reverse_int(num_labels_read);
if (num_labels_read != num_labels) {
printf("Error: Invalid number of labels\n");
exit(1);
}
int i;