AlexNet模型实现(3. C++模型实现)

在 https://github.com/aitazhixin/DL/tree/master/AlexNet 的代码基础上做了一些修改。目前运行效率比较低,准备改用多线程实现。

参照原文,所有卷积层和池化层都采用 ReLU 激活函数,其导数在激活值大于 0 时为 1,否则为 0。第三卷积层以及第一、第二全连接层使用了 Dropout。没有采用原文中的 LRN(局部响应归一化)技术,池化采用最大池化。从 ImageNet 2012 数据集中选取了 10 类图片进行训练,但训练效果不好。贴出代码,希望集思广益,希望高手拔刀相助。

一些初始化或者辅助接口:

// Releases every buffer allocated by newParam().
// All of them come from array new-expressions (new float[n]() / new char[n]()),
// so they must be released with the array form of delete; scalar delete on
// them is undefined behaviour.
// NOTE(review): the class owns raw pointers; unless copy construction/assignment
// are deleted or user-defined in the class declaration (not visible here), a
// copied instance would free the same buffers twice in this destructor.
slqAlexNet::~slqAlexNet()
        {
            if (inRaw)
            {
                delete[] inRaw;   // allocated with new char[inUnitNum]() in newParam()
                inRaw = nullptr;
            }

            // deletevar() frees one float buffer and resets the member to nullptr.

            // label and per-layer maps
            deletevar(&mlabel);
            deletevar(&inMap);
            deletevar(&c1Map);
            deletevar(&s1Map);
            deletevar(&c2Map);
            deletevar(&s2Map);
            deletevar(&c3Map);
            deletevar(&c4Map);
            deletevar(&c5Map);
            deletevar(&s5Map);
            deletevar(&f1Map);
            deletevar(&f2Map);
            deletevar(&f3Map);

            // kernels / pool coefficients / connection weights
            deletevar(&c1Conv);
            deletevar(&s1Pool);
            deletevar(&c2Conv);
            deletevar(&s2Pool);
            deletevar(&c3Conv);
            deletevar(&c4Conv);
            deletevar(&c5Conv);
            deletevar(&s5Pool);
            deletevar(&f1Conn);
            deletevar(&f2Conn);
            deletevar(&f3Conn);

            // biases
            deletevar(&c1Bias);
            deletevar(&s1Bias);
            deletevar(&c2Bias);
            deletevar(&s2Bias);
            deletevar(&c3Bias);
            deletevar(&c4Bias);
            deletevar(&c5Bias);
            deletevar(&s5Bias);
            deletevar(&f1Bias);
            deletevar(&f2Bias);
            deletevar(&f3Bias);

            // *Dt buffers for the maps
            deletevar(&c1MapDt);
            deletevar(&s1MapDt);
            deletevar(&c2MapDt);
            deletevar(&s2MapDt);
            deletevar(&c3MapDt);
            deletevar(&c4MapDt);
            deletevar(&c5MapDt);
            deletevar(&s5MapDt);
            deletevar(&f1MapDt);
            deletevar(&f2MapDt);
            deletevar(&f3MapDt);

            // *Dt buffers for the weights
            deletevar(&c1ConvDt);
            deletevar(&s1PoolDt);
            deletevar(&c2ConvDt);
            deletevar(&s2PoolDt);
            deletevar(&c3ConvDt);
            deletevar(&c4ConvDt);
            deletevar(&c5ConvDt);
            deletevar(&s5PoolDt);
            deletevar(&f1ConnDt);
            deletevar(&f2ConnDt);
            deletevar(&f3ConnDt);

            // *Dt buffers for the biases
            deletevar(&c1BiasDt);
            deletevar(&s1BiasDt);
            deletevar(&c2BiasDt);
            deletevar(&s2BiasDt);
            deletevar(&c3BiasDt);
            deletevar(&c4BiasDt);
            deletevar(&c5BiasDt);
            deletevar(&s5BiasDt);
            deletevar(&f1BiasDt);
            deletevar(&f2BiasDt);
            deletevar(&f3BiasDt);

            // *EDt (momentum) buffers for the weights
            deletevar(&c1ConvEDt);
            deletevar(&s1PoolEDt);
            deletevar(&c2ConvEDt);
            deletevar(&s2PoolEDt);
            deletevar(&c3ConvEDt);
            deletevar(&c4ConvEDt);
            deletevar(&c5ConvEDt);
            deletevar(&s5PoolEDt);
            deletevar(&f1ConnEDt);
            deletevar(&f2ConnEDt);
            deletevar(&f3ConnEDt);

            // *EDt (momentum) buffers for the biases
            deletevar(&c1BiasEDt);
            deletevar(&s1BiasEDt);
            deletevar(&c2BiasEDt);
            deletevar(&s2BiasEDt);
            deletevar(&c3BiasEDt);
            deletevar(&c4BiasEDt);
            deletevar(&c5BiasEDt);
            deletevar(&s5BiasEDt);
            deletevar(&f1BiasEDt);
            deletevar(&f2BiasEDt);
            deletevar(&f3BiasEDt);
        }

        void slqAlexNet::init()
        {
            initParm();
        }

        // Fills the boolean tables CONV3Table, F1Table and F2Table.
        // For every (out, in) index pair the entry is false when `out` lies in the
        // half-open window [in, in + width / 2) and true otherwise, where width is
        // c3ConvDeep, f1UnitNum and f2UnitNum respectively (integer division).
        // Presumably true means "connected" — the original code named the values
        // AO (true) / AX (false); confirm against the forward/backward passes.
        void slqAlexNet::CreateConv3Table()
        {
            for (int outI = 0; outI < c3ConvNum; ++outI)
            {
                for (int inI = 0; inI < c3ConvDeep; ++inI)
                {
                    const bool inWindow = (outI >= inI) && (outI < inI + c3ConvDeep/2);
                    CONV3Table[outI][inI] = !inWindow;
                }
            }

            for (int outI = 0; outI < s5UnitNum; ++outI)
            {
                for (int inI = 0; inI < f1UnitNum; ++inI)
                {
                    const bool inWindow = (outI >= inI) && (outI < inI + f1UnitNum/2);
                    F1Table[outI][inI] = !inWindow;
                }
            }

            for (int outI = 0; outI < f1UnitNum; ++outI)
            {
                for (int inI = 0; inI < f2UnitNum; ++inI)
                {
                    const bool inWindow = (outI >= inI) && (outI < inI + f2UnitNum/2);
                    F2Table[outI][inI] = !inWindow;
                }
            }
        }


        // Frees the float array that *var points to and resets *var to nullptr,
        // so calling it again on the same member is harmless.
        // Every buffer passed here is allocated with `new float[n]()` in newParam(),
        // so the array form `delete[]` is required (scalar `delete` is undefined
        // behaviour for memory obtained from new[]).
        void slqAlexNet::deletevar(float **var)
        {
            if (nullptr != var && nullptr != *var)
            {
                delete[] *var;
                *var = nullptr;
            }
        }


        void slqAlexNet::initParm()
        {
            newParam();

            // all weights follow gauss distribution
            uniform_rand(c1Conv, c1ConvUNum, 0.f, 0.f);
            uniform_rand(s1Pool, s1MapNum, 0.f, 0.f);
            uniform_rand(c2Conv, c2ConvUNum, 0.f, 0.f);
            uniform_rand(s2Pool, s2MapNum, 0.f, 0.f);
            uniform_rand(c3Conv, c3ConvUNum, 0.f, 0.f);
            uniform_rand(c4Conv, c4ConvUNum, 0.f, 0.f);
            uniform_rand(c5Conv, c5ConvUNum, 0.f, 0.f);
            uniform_rand(s5Pool, s5MapNum, 0.f, 0.f);
            uniform_rand(f1Conn, f1ConnNum, 0.f, 0.f);
            uniform_rand(f2Conn, f2ConnNum, 0.f, 0.f);
            uniform_rand(f3Conn, f3ConnNum, 0.f, 0.f);

            // 2th, 4th, 5th conv bias set to 1
            std::fill(c2Bias, c2Bias + c2MapNum, 1.0f);
            std::fill(c4Bias, c4Bias + c4MapNum, 1.0f);
            std::fill(c5Bias, c5Bias + c5MapNum, 1.0f);

        }


        void slqAlexNet::newParam()
        {
            mlabel = new float[f3UnitNum]();
            inRaw = new char[inUnitNum]();
            inMap = new float[inUnitNum]();
            c1Map = new float[c1UnitNum]();
            s1Map = new float[s1UnitNum]();
            c2Map = new float[c2UnitNum]();
            s2Map = new float[s2UnitNum]();
            c3Map = new float[c3UnitNum]();
            c4Map = new float[c4UnitNum]();
            c5Map = new float[c5UnitNum]();
            s5Map = new float[s5UnitNum]();
            f1Map = new float[f1UnitNum]();
            f2Map = new float[f2UnitNum]();
            f3Map = new float[f3UnitNum]();

            c1Conv = new float[c1ConvUNum]();
            s1Pool = new float[s1MapNum]();
            c2Conv = new float[c2ConvUNum]();
            s2Pool = new float[s2MapNum]();
            c3Conv = new float[c3ConvUNum]();
            c4Conv = new float[c4ConvUNum]();
            c5Conv = new float[c5ConvUNum]();
            s5Pool = new float[s5MapNum]();
            f1Conn = new float[f1ConnNum]();
            f2Conn = new float[f2ConnNum]();
            f3Conn = new float[f3ConnNum]();

            c1Bias = new float[c1MapNum]();
            s1Bias = new float[s1MapNum]();
            c2Bias = new float[c2MapNum]();
            s2Bias = new float[s2MapNum]();
            c3Bias = new float[c3MapNum]();
            c4Bias = new float[c4MapNum]();
            c5Bias = new float[c5MapNum]();
            s5Bias = new float[s5MapNum]();
            f1Bias = new float[f1UnitNum]();
            f2Bias = new float[f2UnitNum]();
            f3Bias = new float[f3UnitNum]();

            c1MapDt = new float[c1UnitNum]();
            s1MapDt = new float[s1UnitNum]();
            c2MapDt = new float[c2UnitNum]();
            s2MapDt = new float[s2UnitNum]();
            c3MapDt = new float[c3UnitNum]();
            c4MapDt = new float[c4UnitNum]();
            c5MapDt = new float[c5UnitNum]();
            s5MapDt = new float[s5UnitNum]();
            f1MapDt = new float[f1UnitNum]();
            f2MapDt = new float[f2UnitNum]();
            f3MapDt = new float[f3UnitNum]();

            c1ConvDt = new float[c1ConvUNum]();
            s1PoolDt = new float[s1MapNum]();
            c2ConvDt = new float[c2ConvUNum]();
            s2PoolDt = new float[s2MapNum]();
            c3ConvDt = new float[c3ConvUNum]();
            c4ConvDt = new float[c4ConvUNum]();
            c5ConvDt = new float[c5ConvUNum]();
            s5PoolDt = new float[s5MapNum]();
            f1ConnDt = new float[f1ConnNum]();
            f2ConnDt = new float[f2ConnNum]();
            f3ConnDt = new float[f3ConnNum]();


            c1BiasDt = new float[c1MapNum]();
            s1BiasDt = new float[s1MapNum]();
            c2BiasDt = new float[c2MapNum]();
            s2BiasDt = new float[s2MapNum]();
            c3BiasDt = new float[c3MapNum]();
            c4BiasDt = new float[c4MapNum]();
            c5BiasDt = new float[c5MapNum]();
            s5BiasDt = new float[s5MapNum]();
            f1BiasDt = new float[f1UnitNum]();
            f2BiasDt = new float[f2UnitNum]();
            f3BiasDt = new float[f3UnitNum]();

            c1ConvEDt = new float[c1ConvUNum]();
            s1PoolEDt = new float[s1MapNum]();
            c2ConvEDt = new float[c2ConvUNum]();
            s2PoolEDt = new float[s2MapNum]();
            c3ConvEDt = new float[c3ConvUNum]();
            c4ConvEDt = new float[c4ConvUNum]();
            c5ConvEDt = new float[c5ConvUNum]();
            s5PoolEDt = new float[s5MapNum]();
            f1ConnEDt = new float[f1ConnNum]();
            f2ConnEDt = new float[f2ConnNum]();
            f3ConnEDt = new float[f3ConnNum]();

            c1BiasEDt = new float[c1MapNum]();
            s1BiasEDt = new float[s1MapNum]();
            c2BiasEDt = new float[c2MapNum]();
            s2BiasEDt = new float[s2MapNum]();
            c3BiasEDt = new float[c3MapNum]();
            c4BiasEDt = new float[c4MapNum]();
            c5BiasEDt = new float[c5MapNum]();
            s5BiasEDt = new float[s5MapNum]();
            f1BiasEDt = new float[f1UnitNum]();
            f2BiasEDt = new float[f2UnitNum]();
            f3BiasEDt = new float[f3UnitNum]();
        }

        void slqAlexNet::UpgradeNetwork()
        {
            UpdateParameters(c1ConvDt, c1ConvEDt, c1Conv, c1ConvUNum);
            UpdateParameters(c1BiasDt, c1BiasEDt, c1Bias, c1MapNum);

            UpdateParameters(s1PoolDt, s1PoolEDt, s1Pool, s1MapNum);
            UpdateParameters(s1BiasDt, s1BiasEDt, s1Bias, s1MapNum);

            UpdateParameters(c2ConvDt, c2ConvEDt, c2Conv, c2ConvUNum);
            UpdateParameters(c2BiasDt, c2BiasEDt, c2Bias, c2MapNum);

            UpdateParameters(s2PoolDt, s2PoolEDt, s2Pool, s2MapNum);
            UpdateParameters(s2BiasDt, s2BiasEDt, s2Bias, s2MapNum);

            UpdateParameters(c3ConvDt, c3ConvEDt, c3Conv, c3ConvUNum);
            UpdateParameters(c3BiasDt, c3BiasEDt, c3Bias, c3MapNum);

            UpdateParameters(c4ConvDt, c4ConvEDt, c4Conv, c4ConvUNum);
            UpdateParameters(c4BiasDt, c4BiasEDt, c4Bias, c4MapNum);

            UpdateParameters(c5ConvDt, c5ConvEDt, c5Conv, c5ConvUNum);
            UpdateParameters(c5BiasDt, c5BiasEDt, c5Bias, c5MapNum);

            UpdateParameters(s5PoolDt, s5PoolEDt, s5Pool, s5MapNum);
            UpdateParameters(s5BiasDt, s5BiasEDt, s5Bias, s5MapNum);

            UpdateParameters(f1ConnDt, f1ConnEDt, f1Conn, f1ConnNum);
            UpdateParameters(f1BiasDt, f1BiasEDt, f1Bias, f1UnitNum);

            UpdateParameters(f2ConnDt, f2ConnEDt, f2Conn, f2ConnNum);
            UpdateParameters(f2BiasDt, f2BiasEDt, f2Bias, f2UnitNum);

            UpdateParameters(f3ConnDt, f3ConnEDt, f3Conn, f3ConnNum);
            UpdateParameters(f3BiasDt, f3BiasEDt, f3Bias, f3UnitNum);
        }


        // One SGD-with-momentum step over `len` entries, following the update rule
        // of Krizhevsky et al. (2012):
        //     v <- 0.9 * v - 0.0005 * Alpha * w - Alpha * dL/dw
        //     w <- w + v
        //   delta  : gradient for each entry (presumably dL/dw; UpgradeNetwork pairs
        //            each *Dt buffer with the parameter array it belongs to)
        //   Edelta : momentum buffer v, updated in place
        //   para   : parameters w, updated in place
        //   len    : number of entries in each of the three arrays
        //
        // The previous version computed v <- 0.9 * v - 0.0005 * Alpha * delta. That
        // scaled the gradient by the weight-decay constant (effective learning
        // rate 0.0005 * Alpha) and never applied weight decay to the parameters.
        // NOTE(review): the gradient step is now 2000x larger than before for the
        // same Alpha; retune Alpha if it was chosen for the old rule. Weight decay
        // is applied to every array passed in, biases included.
        void slqAlexNet::UpdateParameters(float *delta, float *Edelta, float *para, int len)
        {
            const float momentum    = 0.9f;
            const float weightDecay = 0.0005f;

            for (int lIdx = 0; lIdx < len; lIdx++)
            {
                Edelta[lIdx] = momentum * Edelta[lIdx]
                             - weightDecay * Alpha * para[lIdx]
                             - Alpha * delta[lIdx];
                para[lIdx] += Edelta[lIdx];
            }
        }

        void slqAlexNet::ConvolutionOpt(float *inPtr, float *outPtr, float *convPtr, float *biPtr, int param[])
        {
            int inMapNo     = param[0];
            int outMapNo    = param[1];
            int inMapH      = param[2];
            int inMapW      = param[3];
            int outMapH     = param[4];
            int outMapW     = param[5];
            int convH       = param[6];
            int convW       = param[7];
            int convStride  = param[8];
            int expand      = param[9];

            int odeepIdx;
            int ideepIdx;
            int ohIdx;
            int owIdx;
            int phIdx;
            int pwIdx;

            int iod;
            int vod;
            int svh;

            int insize = inMapH * inMapW;
            int convsize = inMapNo * convH * convW;
            int convtsor = convH * convW;

            for (odeepIdx = 0; odeepIdx < outMapNo; odeepIdx++)
            {
                for (ohIdx = expand; ohIdx < outMapH - expand; ohIdx++)
                {
                    iod = odeepIdx * outMapH * outMapW + ohIdx * outMapW;
                    for (owIdx = expand; owIdx < outMapW - expand; owIdx++)
                    {
                        float *cu
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值