Torch7入门续集(五)----进一步了解optim

总说

现在看以前写的入门续集,觉得写的好烂,但我不想改了。在torch7学习(六)稍微提到了optim,觉得写的很不清楚,所以有了这篇。
入门续集到了这篇,个人认为看Torch框架的深度学习代码应该没啥大问题了

总览

x*, {f}, ... = optim.method(opfunc, x[, config][, state])

  • opfunc: 自定义的闭包,必须包含:我们计算的f 对于优化的参数x的求导的API,比如大部分我们是训练网络,然后我们的 f 就是loss,而x就是网络的weight。此时我们在opfunc中必须包含”API”, 这个“API”在这里类似是:
-- 这里的loss就是f
local loss = criterion:forward(predict, trainlabels)
local dloss_dpredict = criterion:backward(predict,trainlabels)

-- 这里调用backward,这个backward会计算gradWeight, gradBias以及gradInput.
local gradInput = net:backward(trainset, dloss_dpredict)
  • x: 需要优化的参数,这里的x必须是一维的!
  • config: 根据不同的优化方法,设置不同的选项。
  • state: 这个一般包含learningRate,learningRateDecay之类的。
  • x*: 其中 x* = argmin_x f(x)
  • {f}: 略

一般的写法:

require 'optim'

local optimState = {learningRate = 0.01}

local params, gramParams = net:getParameters()

function feval(params)
-- 无论如何,在f函数中,先要将需要“优化的参数的梯度”设置成0
   gradParams:zero()

-- 重新计算“需要优化的参数”。
   local outputs = model:forward(batchInputs)
   local loss = criterion:forward(outputs, batchLabels)
   local dloss_doutputs = criterion:backward(outputs, batchLabels)
   model:backward(batchInputs, dloss_doutputs)

-- 返回f值,就是loss,以及gramParams(这个必须是一维的)
   return loss, gradParams
end

for epoch = 1, 50 do

   -- 加载点数据,干点额外的事。

   -- 在最后一句调用optim.method
   optim.sgd(feval, params, optimState)
end

再次强调:backward只是调用每一层的 updateGradInput以及accGradParameters,并不会更新参数,只是计算参数的梯度,以及计算每一层的输入的梯度。

local fDx = function(x)
    netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
    netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

    gradParametersD:zero()

    -- Real
    -- train netD with (real, real_label)
    local output = netD:forward(real_AB)
    local label = torch.FloatTensor(output:size()):fill(real_label)
    if opt.gpu>0 then 
        
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值