从代码角度学习理解Pytorch学习框架01: Variable的了解
# coding=utf-8
import torch
from torch.autograd import Variable
"""pytorch中Variable了解"""
"""
Variable是Pytorch中autograd自动微分模块的核心。
它封装了Tensor,支持几乎所有的tensor操作。
主要包含如下3个属性:
1. data: 保存Variable所包含的Tensor
2. grad: 保存data对应的梯度,grad也是一个Variable,而不是一个Tensor,和data的形状一样
3. grad_fn: 指向一个Function对象,这个Function用来反向传播计算输入的梯度
"""
def about_variable():
x = Variable(torch.ones(3, 2), requires_grad=True)
print('x: {}'.format(x))
print('x.data: {}'.format(x.data))
print('x.grad: {}'.format(x.grad))
# Variable和Tensor具有几乎一致的接口
aa_variable = Variable(torch.ones(3, 2))
print('torch.cos(aa_variable): {}'.format(torch.cos(aa_variable)))
print('torch.cos(aa_variable.data): {}'.format(torch.cos(aa_variable.data)))
if __name__ == '__main__':
about_variable()