neuraltalk2-代码解析-(5)-train.lua

赶紧完成这个neuraltalk2系列,这是这个系列的最后一篇博客,因为如果大家以后想在neuraltalk2的代码基础上实现对自己的image caption的实验,知道这几个文件是足够的,尤其是对我这种小白而言,torch是我登入deep learning这个领域的第一个着落点,将来我也肯定会尝试tensorflow等平台,提升代码能力也是我这学期的一个重要目标,不多说开始。

这个文件前面都是参数的设定,其中帮助信息已经写得十分详细了,我就不多说了。

我直接从eval_split开始。

  • eval_split(split,evalopt)
--这个函数是对valid集进行实验评估,其实可以对任意数据集做评估,这跟输入的参数split有关,split的实际值域可以理解为,"train","valid","test"
local function eval_split(split, evalopt)
  local verbose = utils.getopt(evalopt, 'verbose', true)
  --val_images_use表示的是val集所用图片的数量
  local val_images_use = utils.getopt(evalopt, 'val_images_use', true)
  protos.cnn:evaluate()
  protos.lm:evaluate()
  --重置split属性的迭代器,其作用跟据迭代器来访问split数据集中各数据
  loader:resetIterator(split) -- rewind iteator back to first datapoint in the split
  --n为计数器
  local n = 0
  --损失函数的总值
  local loss_sum = 0
  --损失的评估值
  local loss_evals = 0
  --最后的预测结果
  local predictions = {}
  --得到index索引到词的映射空间
  local vocab = loader:getVocab()
  while true do
    -- fetch a batch of data
    -- 取得一个batch_size的数据(图片加序列),由于一个图片需要扩展成seq_per_img来训练,提升训练的效率,所以实际图片向量的大小为batch_size*seq_per_img
    local data = loader:getBatch{batch_size = opt.batch_size, split = split, seq_per_img = opt.seq_per_img}
    --对图片进行预处理
    data.images = net_utils.prepro(data.images, false, opt.gpuid >= 0) -- preprocess in place, and don't augment
    --n为评估图片的总数量
    n = n + data.images:size(1)
    -- forward the model to get loss
    --feats是一个2维的matrix,第一维的大小为batch_size,第二维的大小为encoding_size
    local feats = protos.cnn:forward(data.images)
    --expanded_feats是一个2维matrix,第一维的大小为batch_size*seq_per_img,第二维的大小为encoding_size
    local expanded_feats = protos.expander:forward(feats)
    --经过语言模型,这里logprobs是一个table,每个元素为所预测的单词向量
    local logprobs = protos.lm:forward{expanded_feats, data.labels}
    --经过校准层,得最后的损失值
    local loss = protos.crit:forward(logprobs, data.labels)
    --累积损失值
    loss_sum = loss_sum + loss
    --评估数量加1
    loss_evals = loss_evals + 1
    -- forward the model to also get generated samples for each image
    -- 这里是将输入图片的特征完全转换为真实的预测序列,即文档
    -- seq为预测结果的索引值
    local seq = protos.lm:sample(feats)
    --sents为预测结果的文档
    local sents = net_utils.decode_sequence(vocab, seq)
    --储存预测结果
    for k=1,#sents do
      local entry = {image_id = data.infos[k].id, caption = sents[k]}
      table.insert(predictions, entry)
      if verbose then
        print(string.format('image %s: %s', entry.image_id, entry.caption))
      end
    end
    -- if we wrapped around the split or used up val imgs budget then bail
    -- ix0指向的当前评估的图像的索引编号
    local ix0 = data.bounds.it_pos_now
    -- ix1是图像索引编号的的最大值
    local ix1 = math.min(data.bounds.it_max, val_images_use)
    if verbose then
      print(string.format('evaluating validation performance... %d/%d (%f)', ix0-1, ix1, loss))
    end
    --没评估10次,回收垃圾
    if loss_evals % 10 == 0 then collectgarbage() end
    --如果已经遍历完了所有数据集中的数据,结束评估
    if data.bounds.wrapped then break end -- the split ran out of data, lets break out
    if n >= val_images_use then break end -- we've used enough images
  end
  --lang_stats为打分值
  local lang_stats
  if opt.language_eval == 1 then
    lang_stats = net_utils.language_eval(predictions, opt.id)
  end

  return loss_sum/loss_evals, predictions, lang_stats
end
  • lossfun
local function lossFun()
  --将cnn,lm的mode设定为(train=true),这对dropout有效
  protos.cnn:training()
  protos.lm:training()
  --将参数清零
  grad_params:zero()
  --如果需要微调cnn,则将cnn的参数也清零         
  if opt.finetune_cnn_after >= 0 and iter >= opt.finetune_cnn_after then
    cnn_grad_params:zero()
  end
  -- get batch of data
  -- 跟上面一样,取得数据  
  local data = loader:getBatch{batch_size = opt.batch_size, split = 'train', seq_per_img = opt.seq_per_img}
  -- 预处理img信息,下面很多与上个函数一样,不多说了
  data.images = net_utils.prepro(data.images, true, opt.gpuid >= 0) -- preprocess in place, do data augmentation
  -- data.images: Nx3x224x224 
  -- data.seq: LxM where L is sequence length upper bound, and M = N*seq_per_img
  -- forward the ConvNet on images (most work happens here)
  local feats = protos.cnn:forward(data.images)
-- we have to expand out image features, once for each sentence
  local expanded_feats = protos.expander:forward(feats)
  -- forward the language model
  local logprobs = protos.lm:forward{expanded_feats, data.labels}
  -- forward the language model criterion
  local loss = protos.crit:forward(logprobs, data.labels)

  -----------------------------------------------------------------------------
  -- Backward pass
  -----------------------------------------------------------------------------
  -- backprop criterion
  -- 这里进行梯度的传导,求出个梯度
  -- 校准层的输入梯度,校准层的输入为,LSTM层的输出
  local dlogprobs = protos.crit:backward(logprobs, data.labels)
  -- backprop language model
  -- 获得lm层的输入梯度,这里注意dexpand_feats为图像输入的梯度,ddumy则是一个无意义的matrix
  local dexpanded_feats, ddummy = unpack(protos.lm:backward({expanded_feats, data.labels}, dlogprobs))
  -- backprop the CNN, but only if we are finetuning
  -- 如果cnn需要微调则进行微调处理
  if opt.finetune_cnn_after >= 0 and iter >= opt.finetune_cnn_after then
    local dfeats = protos.expander:backward(feats, dexpanded_feats)
    local dx = protos.cnn:backward(data.images, dfeats)
  end

  -- clip gradients
  -- print(string.format('claming %f%% of gradients', 100*torch.mean(torch.gt(torch.abs(grad_params), opt.grad_clip))))
  -- 控制梯度的大小,clamp(a,b)这个函数是a为最小值,b为最大值,如果grad_params中的元素大于b则取b,小于a则取a
  grad_params:clamp(-opt.grad_clip, opt.grad_clip)
  -- 是否应用L2 regularization
  -- apply L2 regularization
  if opt.cnn_weight_decay > 0 then
    cnn_grad_params:add(opt.cnn_weight_decay, cnn_params)
    -- note: we don't bother adding the l2 loss to the total loss, meh.
    cnn_grad_params:clamp(-opt.grad_clip, opt.grad_clip)
  end
  -----------------------------------------------------------------------------

  -- and lets get out!
  local losses = { total_loss = loss }
  return losses
end

最后一个代码块,不分析了,如果想借用这个代码,最后的代码块也不需要改,写完这个系列突然觉得,自己有点蠢蠢的!!!,不说了,赶紧写完,我以后会发有营养的博客。加油!Fighting!

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值