深度学习C++代码配套教程(3. 数据文件读取)

导航栏

深度学习C++代码 (位于 Github)
深度学习C++代码配套教程(1. 总述)
深度学习C++代码配套教程(2. 基础数据操作)
深度学习C++代码配套教程(3. 数据文件读取)
深度学习C++代码配套教程(4. ANN 经典神经网络)
深度学习C++代码配套教程(5. CNN 卷积神经网络)


数据文件管理是每种程序设计语言的基础. 不熟悉的时候会遇到一些坑, 过了就没事儿了.

1. 数据格式

数据以文本方式存储, 后缀不一定是 txt. 如下是 irist.txt 的前几行
5.1,3.5,1.4,0.2,0
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
4.6,3.1,1.5,0.2,0
有如下几点要求:

  1. 不需要文件头来说明有多少个对象、属性、类别, 这些通过文件扫描自动提取;
  2. 每行表示一个对象, 对象之间不应有空行;
  3. 属性之间用逗号分隔, 条件属性均识别为 double 类型;
  4. 类别为最后一个属性, 按 int 类型处理;
  5. 不支持缺值.
    为保持程序的简洁性, 不对这些错误进行纠正.

2. MfDataReader 类

成员变量申明如下:

//The number of instances
int numInstances;
//The number of conditions
int numConditions;
//The number of classes
int numClasses;

//For data randomization
MfIntArray* randomArray;

//The whole input
MfDoubleMatrix* wholeX;
//The labels of the whole data
MfIntArray* wholeY;

//The training data
MfDoubleMatrix* trainingX;
//The labels of the training data
MfIntArray* trainingY;
//The testing data
MfDoubleMatrix* testingX;
//The labels of the testing data
MfIntArray* testingY;

这里比较特别的仅有 randomArray. 很多数据有规律存放, 如 iris 的前 50 个是 0 类, 中间 50 个是 1 类, 最后 50 个是 2 类. 如果把前 60% 作为训练集, 其它作为测试集, 就尴尬了. 通过 randomArray 获得随机化的一个下标数组, 再通过它来间址获得前 60% 数据, 就可以克服该问题. 如下是间址使用的例子:
trainingX->setValue(i, j, wholeX->getValue(randomArray->getValue(i), j));
它将 wholeX (整个数据集) 的第 randomArray->getValue(i) 数据拷贝成 trainingX 的第 i 行数据.

方法申明如下:

//Empty constructor, not useful.
MfDataReader();
//Read the data from the given file
MfDataReader(char* paraFilename);
//Destructor
virtual ~MfDataReader();

//Split the data into the training and testing parts according to the given fraction
void splitInTwo(double paraTrainingFraction);
//Split the data according to cross-validation
void crossValidationSplit(int paraNumFolds, int paraFoldIndex);

//The getter
MfDoubleMatrix* getTrainingX();
//The getter
MfIntArray* getTrainingY();
//The getter
MfDoubleMatrix* getTestingX();
//The getter
MfIntArray* getTestingY();
//The getter
MfDoubleMatrix* getWholeX();
//The getter
MfIntArray* getWholeY();

//The random array is stored in the object
void randomize();
//Code unit test
void unitTest();

多数函数都很直接, randomize 仅仅重新随机化 randomArray. 但需要注意, 如果不调用它, randomArray 就为 [0, 1, 2, …], 即未随机化.
以下是 3 个重要的函数.

2.1 构造函数

构造方法将数据读入. 有两遍扫描.
第 1 遍扫描获得 numInstances, numConditions;
第 2 遍将数据读入 wholeX 和 wholeY, 并获得 numClasses.

2.2 splitInTwo 函数

由于有 randomArray 的存在, splitInTwo 就仅需要把前面的数据拷贝到训练数据集中即可. 这在 randomArray 中已经说过了.

2.3 crossValidationSplit 函数

void MfDataReader::crossValidationSplit(int paraNumFolds, int paraFoldIndex)
{
    int tempTestingSize = numInstances / paraNumFolds;
    if (paraFoldIndex < numInstances % paraNumFolds)
    {
        tempTestingSize ++;
    }//Of if
    int tempTrainingSize = numInstances - tempTestingSize;

    //Free space if allocated in the past
    if (trainingX != nullptr)
    {
        free(trainingX);
        free(trainingY);
        free(testingX);
        free(testingY);
    }//Of if
    trainingX = new MfDoubleMatrix(tempTrainingSize, numConditions);
    trainingY = new MfIntArray(tempTrainingSize);
    testingX = new MfDoubleMatrix(tempTestingSize, numConditions);
    testingY = new MfIntArray(tempTestingSize);
    int tempTrainingIndex = 0;
    int tempTestingIndex = 0;

    for(int i = 0; i < numInstances; i ++)
    {
        if (i % paraNumFolds != paraFoldIndex)
        {
            for(int j = 0; j < numConditions; j ++)
            {
                trainingX->setValue(tempTrainingIndex, j, wholeX->getValue(randomArray->getValue(i), j));
            }//Of for j
            trainingY->setValue(tempTrainingIndex, wholeY->getValue(randomArray->getValue(i)));
            tempTrainingIndex ++;
        }
        else
        {
            for(int j = 0; j < numConditions; j ++)
            {
                testingX->setValue(tempTestingIndex, j, wholeX->getValue(randomArray->getValue(i), j));
            }//Of for j
            testingY->setValue(tempTestingIndex, wholeY->getValue(randomArray->getValue(i)));
            tempTestingIndex ++;
        }//Of if
    }//Of for i
}//Of crossValidationSplit

这个函数仅生成一个训练集/测试集对. 如果要进行真正的 cross validation, 需要对它进行 paraNumFolds 次调用, 然后进行相应的训练/测试. 该函数有如下特点:

  1. 代码中取余操作保证了 CV 分割的有效性;
  2. 不需要把所有的训练/测试集对一次性生成, 节约了空间.

3. 小结

如果进行本类的单元测试,程序结束时会报错. 这是内存管理 free 导致的, 见 2.3 节.当前我还没找到解决方案.

点击进入下一节

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值