参考来源于:Pytorch Tutorial – Chongruo Wu
写的真的不错,在这里分享下:
首先是torch的三个主要模块:Tensor、 Variable 、Module。
-
Tensor 可以视为
ndarray
, 但是可以在GPU上做计算,比如下图中的cuda设置; Variable 是计算图上的一个节点,存储数据和梯度;Module含有神经网络的层,可用于存储可训练的网络参数等。 -
下图是cuda设置数据类型:
-
这个图说明torch.autograd.Variable:
-
下面的图说明Module用来写自己的神经网络类: