对代码进行了一些修改(见 https://github.com/aitazhixin/DL/tree/master/AlexNet )。目前的运行效率比较低,准备采用多线程的方式重新实现一下。
根据原文,所有卷积层和池化层采用了ReLU激活函数,激活函数的导数为0或1(激活值非0时)。第三卷积层和第一、第二全连接层使用了Dropout方式。没有采用原文使用的LRN技术,采用了最大池化。在ImageNet2012数据集中选取了10类图片进行训练,但是迭代的效果不好。贴出代码希望集思广益,希望高手拔刀相助。
一些初始化或者辅助接口:
// Destructor: releases every buffer allocated in newParam().
// BUG FIX: inRaw is allocated with `new char[inUnitNum]` but was released
// with plain `delete`, which is undefined behaviour for array allocations;
// it must be `delete[]`.
slqAlexNet::~slqAlexNet()
{
    if (inRaw)
    {
        delete[] inRaw;
        inRaw = nullptr;
    }
    // Labels and feature maps (forward activations).
    deletevar(&mlabel);
    deletevar(&inMap);
    deletevar(&c1Map);
    deletevar(&s1Map);
    deletevar(&c2Map);
    deletevar(&s2Map);
    deletevar(&c3Map);
    deletevar(&c4Map);
    deletevar(&c5Map);
    deletevar(&s5Map);
    deletevar(&f1Map);
    deletevar(&f2Map);
    deletevar(&f3Map);
    // Weights: convolution kernels, pooling coefficients, FC matrices.
    deletevar(&c1Conv);
    deletevar(&s1Pool);
    deletevar(&c2Conv);
    deletevar(&s2Pool);
    deletevar(&c3Conv);
    deletevar(&c4Conv);
    deletevar(&c5Conv);
    deletevar(&s5Pool);
    deletevar(&f1Conn);
    deletevar(&f2Conn);
    deletevar(&f3Conn);
    // Biases, one per output map / unit.
    deletevar(&c1Bias);
    deletevar(&s1Bias);
    deletevar(&c2Bias);
    deletevar(&s2Bias);
    deletevar(&c3Bias);
    deletevar(&c4Bias);
    deletevar(&c5Bias);
    deletevar(&s5Bias);
    deletevar(&f1Bias);
    deletevar(&f2Bias);
    deletevar(&f3Bias);
    // Backprop deltas for activations.
    deletevar(&c1MapDt);
    deletevar(&s1MapDt);
    deletevar(&c2MapDt);
    deletevar(&s2MapDt);
    deletevar(&c3MapDt);
    deletevar(&c4MapDt);
    deletevar(&c5MapDt);
    deletevar(&s5MapDt);
    deletevar(&f1MapDt);
    deletevar(&f2MapDt);
    deletevar(&f3MapDt);
    // Gradients for weights and biases.
    deletevar(&c1ConvDt);
    deletevar(&s1PoolDt);
    deletevar(&c2ConvDt);
    deletevar(&s2PoolDt);
    deletevar(&c3ConvDt);
    deletevar(&c4ConvDt);
    deletevar(&c5ConvDt);
    deletevar(&s5PoolDt);
    deletevar(&f1ConnDt);
    deletevar(&f2ConnDt);
    deletevar(&f3ConnDt);
    deletevar(&c1BiasDt);
    deletevar(&s1BiasDt);
    deletevar(&c2BiasDt);
    deletevar(&s2BiasDt);
    deletevar(&c3BiasDt);
    deletevar(&c4BiasDt);
    deletevar(&c5BiasDt);
    deletevar(&s5BiasDt);
    deletevar(&f1BiasDt);
    deletevar(&f2BiasDt);
    deletevar(&f3BiasDt);
    // Momentum ("E" = exponentially averaged) terms used by UpdateParameters.
    deletevar(&c1ConvEDt);
    deletevar(&s1PoolEDt);
    deletevar(&c2ConvEDt);
    deletevar(&s2PoolEDt);
    deletevar(&c3ConvEDt);
    deletevar(&c4ConvEDt);
    deletevar(&c5ConvEDt);
    deletevar(&s5PoolEDt);
    deletevar(&f1ConnEDt);
    deletevar(&f2ConnEDt);
    deletevar(&f3ConnEDt);
    deletevar(&c1BiasEDt);
    deletevar(&s1BiasEDt);
    deletevar(&c2BiasEDt);
    deletevar(&s2BiasEDt);
    deletevar(&c3BiasEDt);
    deletevar(&c4BiasEDt);
    deletevar(&c5BiasEDt);
    deletevar(&s5BiasEDt);
    deletevar(&f1BiasEDt);
    deletevar(&f2BiasEDt);
    deletevar(&f3BiasEDt);
}
// Public entry point for network setup: allocates all buffers and seeds
// the weights by delegating to initParm().
void slqAlexNet::init()
{
initParm();
}
// Builds the sparse connection tables for the third convolution layer and
// the two fully-connected layers (true = connected, false = skipped).
// A pair (out, in) is *disconnected* when `out` lies in the half-open band
// [in, in + depth/2); every other pair is connected.
// NOTE(review): despite the name, this routine also fills F1Table and
// F2Table, not just CONV3Table.
void slqAlexNet::CreateConv3Table()
{
    for (int out = 0; out < c3ConvNum; out++)
    {
        for (int in = 0; in < c3ConvDeep; in++)
        {
            const bool banded = (out >= in) && (out < in + c3ConvDeep / 2);
            CONV3Table[out][in] = !banded;
        }
    }
    for (int out = 0; out < s5UnitNum; out++)
    {
        for (int in = 0; in < f1UnitNum; in++)
        {
            const bool banded = (out >= in) && (out < in + f1UnitNum / 2);
            F1Table[out][in] = !banded;
        }
    }
    for (int out = 0; out < f1UnitNum; out++)
    {
        for (int in = 0; in < f2UnitNum; in++)
        {
            const bool banded = (out >= in) && (out < in + f2UnitNum / 2);
            F2Table[out][in] = !banded;
        }
    }
}
// Releases one float buffer and nulls the caller's pointer so a double
// call is harmless.
// BUG FIX: every buffer handed to this function is allocated with
// `new float[...]` (see newParam()), so it must be released with
// `delete[]`; the original plain `delete` is undefined behaviour for
// array allocations.
void slqAlexNet::deletevar(float **var)
{
    if (nullptr != *var)
    {
        delete[] *var;
        *var = nullptr;
    }
}
// Allocates all network buffers and seeds the trainable parameters.
// BUG FIX: the original passed (0.f, 0.f) as the distribution bounds to
// every uniform_rand call, which fills ALL weights with exactly zero. A
// symmetric all-zero initialization makes every unit in a layer compute
// the same value and gradient, so the network cannot learn — the likely
// cause of the poor training reported. Seed with small symmetric random
// values instead.
// NOTE(review): this assumes uniform_rand(ptr, n, lo, hi) draws uniformly
// from [lo, hi] — confirm against its definition. The original comment
// mentions a Gaussian, but the helper is named uniform_rand.
void slqAlexNet::initParm()
{
    newParam();
    // Small symmetric random weights break unit symmetry.
    uniform_rand(c1Conv, c1ConvUNum, -0.01f, 0.01f);
    uniform_rand(s1Pool, s1MapNum, -0.01f, 0.01f);
    uniform_rand(c2Conv, c2ConvUNum, -0.01f, 0.01f);
    uniform_rand(s2Pool, s2MapNum, -0.01f, 0.01f);
    uniform_rand(c3Conv, c3ConvUNum, -0.01f, 0.01f);
    uniform_rand(c4Conv, c4ConvUNum, -0.01f, 0.01f);
    uniform_rand(c5Conv, c5ConvUNum, -0.01f, 0.01f);
    uniform_rand(s5Pool, s5MapNum, -0.01f, 0.01f);
    uniform_rand(f1Conn, f1ConnNum, -0.01f, 0.01f);
    uniform_rand(f2Conn, f2ConnNum, -0.01f, 0.01f);
    uniform_rand(f3Conn, f3ConnNum, -0.01f, 0.01f);
    // Per the AlexNet paper, the 2nd, 4th and 5th conv biases start at 1
    // to give ReLUs positive inputs early in training.
    std::fill(c2Bias, c2Bias + c2MapNum, 1.0f);
    std::fill(c4Bias, c4Bias + c4MapNum, 1.0f);
    std::fill(c5Bias, c5Bias + c5MapNum, 1.0f);
}
// Allocates every buffer the network uses, value-initialized to zero.
// Four parallel families share the same sizes:
//   <x>      - forward values (activations / weights / biases)
//   <x>Dt    - gradients from backprop
//   <x>EDt   - momentum accumulators used by UpdateParameters
// NOTE(review): these are raw owning new[] allocations paired with the
// manual deletevar() teardown; std::vector<float> members would remove
// the leak/mismatch risk entirely.
void slqAlexNet::newParam()
{
// Labels and forward feature maps.
mlabel = new float[f3UnitNum]();
inRaw = new char[inUnitNum]();
inMap = new float[inUnitNum]();
c1Map = new float[c1UnitNum]();
s1Map = new float[s1UnitNum]();
c2Map = new float[c2UnitNum]();
s2Map = new float[s2UnitNum]();
c3Map = new float[c3UnitNum]();
c4Map = new float[c4UnitNum]();
c5Map = new float[c5UnitNum]();
s5Map = new float[s5UnitNum]();
f1Map = new float[f1UnitNum]();
f2Map = new float[f2UnitNum]();
f3Map = new float[f3UnitNum]();
// Weights: conv kernels, pooling coefficients, FC matrices.
c1Conv = new float[c1ConvUNum]();
s1Pool = new float[s1MapNum]();
c2Conv = new float[c2ConvUNum]();
s2Pool = new float[s2MapNum]();
c3Conv = new float[c3ConvUNum]();
c4Conv = new float[c4ConvUNum]();
c5Conv = new float[c5ConvUNum]();
s5Pool = new float[s5MapNum]();
f1Conn = new float[f1ConnNum]();
f2Conn = new float[f2ConnNum]();
f3Conn = new float[f3ConnNum]();
// Biases, one per output map / unit.
c1Bias = new float[c1MapNum]();
s1Bias = new float[s1MapNum]();
c2Bias = new float[c2MapNum]();
s2Bias = new float[s2MapNum]();
c3Bias = new float[c3MapNum]();
c4Bias = new float[c4MapNum]();
c5Bias = new float[c5MapNum]();
s5Bias = new float[s5MapNum]();
f1Bias = new float[f1UnitNum]();
f2Bias = new float[f2UnitNum]();
f3Bias = new float[f3UnitNum]();
// Backprop deltas for activations (same sizes as the maps above).
c1MapDt = new float[c1UnitNum]();
s1MapDt = new float[s1UnitNum]();
c2MapDt = new float[c2UnitNum]();
s2MapDt = new float[s2UnitNum]();
c3MapDt = new float[c3UnitNum]();
c4MapDt = new float[c4UnitNum]();
c5MapDt = new float[c5UnitNum]();
s5MapDt = new float[s5UnitNum]();
f1MapDt = new float[f1UnitNum]();
f2MapDt = new float[f2UnitNum]();
f3MapDt = new float[f3UnitNum]();
// Gradients for weights and biases.
c1ConvDt = new float[c1ConvUNum]();
s1PoolDt = new float[s1MapNum]();
c2ConvDt = new float[c2ConvUNum]();
s2PoolDt = new float[s2MapNum]();
c3ConvDt = new float[c3ConvUNum]();
c4ConvDt = new float[c4ConvUNum]();
c5ConvDt = new float[c5ConvUNum]();
s5PoolDt = new float[s5MapNum]();
f1ConnDt = new float[f1ConnNum]();
f2ConnDt = new float[f2ConnNum]();
f3ConnDt = new float[f3ConnNum]();
c1BiasDt = new float[c1MapNum]();
s1BiasDt = new float[s1MapNum]();
c2BiasDt = new float[c2MapNum]();
s2BiasDt = new float[s2MapNum]();
c3BiasDt = new float[c3MapNum]();
c4BiasDt = new float[c4MapNum]();
c5BiasDt = new float[c5MapNum]();
s5BiasDt = new float[s5MapNum]();
f1BiasDt = new float[f1UnitNum]();
f2BiasDt = new float[f2UnitNum]();
f3BiasDt = new float[f3UnitNum]();
// Momentum accumulators for the weight/bias updates.
c1ConvEDt = new float[c1ConvUNum]();
s1PoolEDt = new float[s1MapNum]();
c2ConvEDt = new float[c2ConvUNum]();
s2PoolEDt = new float[s2MapNum]();
c3ConvEDt = new float[c3ConvUNum]();
c4ConvEDt = new float[c4ConvUNum]();
c5ConvEDt = new float[c5ConvUNum]();
s5PoolEDt = new float[s5MapNum]();
f1ConnEDt = new float[f1ConnNum]();
f2ConnEDt = new float[f2ConnNum]();
f3ConnEDt = new float[f3ConnNum]();
c1BiasEDt = new float[c1MapNum]();
s1BiasEDt = new float[s1MapNum]();
c2BiasEDt = new float[c2MapNum]();
s2BiasEDt = new float[s2MapNum]();
c3BiasEDt = new float[c3MapNum]();
c4BiasEDt = new float[c4MapNum]();
c5BiasEDt = new float[c5MapNum]();
s5BiasEDt = new float[s5MapNum]();
f1BiasEDt = new float[f1UnitNum]();
f2BiasEDt = new float[f2UnitNum]();
f3BiasEDt = new float[f3UnitNum]();
}
// Applies one momentum-SGD step to every trainable parameter set, layer
// by layer, using the gradients (*Dt) and momentum buffers (*EDt)
// accumulated during backprop. Delegates the per-array arithmetic to
// UpdateParameters.
void slqAlexNet::UpgradeNetwork()
{
// Convolution layers 1-5 (with pooling coefficients for s1/s2/s5).
UpdateParameters(c1ConvDt, c1ConvEDt, c1Conv, c1ConvUNum);
UpdateParameters(c1BiasDt, c1BiasEDt, c1Bias, c1MapNum);
UpdateParameters(s1PoolDt, s1PoolEDt, s1Pool, s1MapNum);
UpdateParameters(s1BiasDt, s1BiasEDt, s1Bias, s1MapNum);
UpdateParameters(c2ConvDt, c2ConvEDt, c2Conv, c2ConvUNum);
UpdateParameters(c2BiasDt, c2BiasEDt, c2Bias, c2MapNum);
UpdateParameters(s2PoolDt, s2PoolEDt, s2Pool, s2MapNum);
UpdateParameters(s2BiasDt, s2BiasEDt, s2Bias, s2MapNum);
UpdateParameters(c3ConvDt, c3ConvEDt, c3Conv, c3ConvUNum);
UpdateParameters(c3BiasDt, c3BiasEDt, c3Bias, c3MapNum);
UpdateParameters(c4ConvDt, c4ConvEDt, c4Conv, c4ConvUNum);
UpdateParameters(c4BiasDt, c4BiasEDt, c4Bias, c4MapNum);
UpdateParameters(c5ConvDt, c5ConvEDt, c5Conv, c5ConvUNum);
UpdateParameters(c5BiasDt, c5BiasEDt, c5Bias, c5MapNum);
UpdateParameters(s5PoolDt, s5PoolEDt, s5Pool, s5MapNum);
UpdateParameters(s5BiasDt, s5BiasEDt, s5Bias, s5MapNum);
// Fully-connected layers 1-3.
UpdateParameters(f1ConnDt, f1ConnEDt, f1Conn, f1ConnNum);
UpdateParameters(f1BiasDt, f1BiasEDt, f1Bias, f1UnitNum);
UpdateParameters(f2ConnDt, f2ConnEDt, f2Conn, f2ConnNum);
UpdateParameters(f2BiasDt, f2BiasEDt, f2Bias, f2UnitNum);
UpdateParameters(f3ConnDt, f3ConnEDt, f3Conn, f3ConnNum);
UpdateParameters(f3BiasDt, f3BiasEDt, f3Bias, f3UnitNum);
}
// Momentum-SGD update for one parameter array, following the rule from
// Krizhevsky et al. (2012), which this network implements (see notes):
//   v <- 0.9*v - 0.0005*Alpha*w - Alpha*grad;   w <- w + v
// BUG FIX: the original applied the 0.0005 weight-decay coefficient to
// the *gradient* instead of the weight. That both drops the decay term
// entirely and shrinks the effective learning rate by 2000x — a likely
// cause of the reported lack of training progress. The 0.0005 literal is
// also made float (0.0005f) to avoid a float->double->float round trip.
//   delta  - gradient dL/dw for this array
//   Edelta - momentum accumulator (updated in place)
//   para   - parameters (updated in place)
//   len    - number of elements in all three arrays
void slqAlexNet::UpdateParameters(float *delta, float *Edelta, float *para, int len)
{
    for (int lIdx = 0; lIdx < len; lIdx++)
    {
        Edelta[lIdx] = 0.9f * Edelta[lIdx]
                     - 0.0005f * Alpha * para[lIdx]  // weight decay on the parameter
                     - Alpha * delta[lIdx];          // gradient step
        para[lIdx] += Edelta[lIdx];
    }
}
void slqAlexNet::ConvolutionOpt(float *inPtr, float *outPtr, float *convPtr, float *biPtr, int param[])
{
int inMapNo = param[0];
int outMapNo = param[1];
int inMapH = param[2];
int inMapW = param[3];
int outMapH = param[4];
int outMapW = param[5];
int convH = param[6];
int convW = param[7];
int convStride = param[8];
int expand = param[9];
int odeepIdx;
int ideepIdx;
int ohIdx;
int owIdx;
int phIdx;
int pwIdx;
int iod;
int vod;
int svh;
int insize = inMapH * inMapW;
int convsize = inMapNo * convH * convW;
int convtsor = convH * convW;
for (odeepIdx = 0; odeepIdx < outMapNo; odeepIdx++)
{
for (ohIdx = expand; ohIdx < outMapH - expand; ohIdx++)
{
iod = odeepIdx * outMapH * outMapW + ohIdx * outMapW;
for (owIdx = expand; owIdx < outMapW - expand; owIdx++)
{
float *cu