学习pytorch的过程留意到的三种类型:
- torch.FloatTensor
- torch.autograd.variable.Variable
- torch.nn.parameter.Parameter(是Variable的子类)
如果在网络的训练过程中需要更新,就要定义为Parameter, 类似为W(权重)和b(偏置)也都是Parameter
Tensor的创建:
import torch
import torch.autograd as autograd
x = torch.randn((2, 2))
print(type(x))
Tensor转化为Variable:
var_x = autograd.