学习pytorch的过程留意到的三种类型:
- torch.FloatTensor
- torch.autograd.variable.Variable
- torch.nn.parameter.Parameter(是Variable的子类)
如果在网络的训练过程中需要更新,就要定义为Parameter, 类似为W(权重)和b(偏置)也都是Parameter
Tensor的创建:
import torch
import torch.autograd as autograd
x = torch.randn((2, 2))
print(type(x))
Tensor转化为Variable:
var_x = autograd.

本文介绍了PyTorch中的三种核心类型:Tensor、Variable和Parameter。Tensor用于存储数据,不具备梯度计算功能;Variable是带有数据和梯度的Tensor,用于后向传播;Parameter是Variable的子类,专门用于网络训练中需要更新的权重和偏置,具备自动梯度计算能力。转换过程包括Tensor转Variable和Variable转Parameter。
最低0.47元/天 解锁文章
1221

被折叠的 条评论
为什么被折叠?



