用来记录pytorch学习构建lenet5网络的过程中,遇到的不了解及一些知识。
torch.nn
-
Conv1d:一维卷积,常用于对文本数据。宽度卷积高度不卷积。可以词向量维度*句子最大长度作为输入。
-
Conv2d:二维卷积,用于图像数据。
-
logsoftmax:对softmax结果取log
-
torch.utils.data.DataLoader:数据集加载器。自动将数据分成batch,并打算顺序
-
nn.NLLLoss():常用于多分类任务的损失函数,常和logsoftmax搭配。
注:logsoftmax+NLLLoss就等同于交叉熵损失函数。 -
torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)
梯度下降法。params待优化参数,lr学习率,momentum动量因子默认0,weight_decay权重衰减,nesterov使用nesterov动量、默认false,
7.output.data.max(1, keepdim=True)[1]
其中1表示找第二维,max找最大值,即找第二维的最大值。keepdim为True表示保持维度、为False表示输出比输入少一个维度。