pytorch Variable用法
什么是 Variable
如果用numpy或者Tensor来实现神经网络,需要手动写出前向过程和反向过程。
对于简单的网络,反向过程中的导数容易求得,但是随着网络深度以及网络复杂度的增加,求出梯度的解析表达式是非常困难的。
PyTorch的包autograd提供了自动求导的功能,当使用autograd时,定义前向网络会生成 一个计算图,每个节点是一个Tensor,边表示由输入Tensor到输出Tensor的函数。
沿着计算图的反向传播可以很容易地计算出各个变量的梯度。在实现的时候,用到了Variable对象。
具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。
Variable的属性有三个:
data:Variable里Tensor变量的数值
grad:Variable反向传播的梯度
grad_fn:表示是通过什么操作得到这个变量的例如( 加减乘除、卷积、反置卷积)
二、Variable的创建和使用
我们定义一个 Variable:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
# 先生鸡蛋
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True)
print(tensor)
print(variable)
tensor([[1., 2.],
[3., 4.]])
tensor([[1., 2.],
[3., 4.]], requires_grad=True)
这里requires_grad=True意义是否对这个变量求梯度,默认的 Fa!se:
Variable的三个属性
我们依次打印它们:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
#创建Variable
a = Variable(torch.Tensor())
print(a)
b = Variable(torch.Tensor([[1, 2], [3, 4],[5, 6], [7, 8]]))
print(b)
print(b.data)
print(b.grad)
print(b.grad_fn)
tensor([])
tensor([[1., 2.],
[3., 4.],
[5., 6.],
[7., 8.]])
tensor([[1., 2.],
[3., 4.],
[5., 6.],
[7., 8.]])
None
None
Variable 计算梯度
我们再对比一下 tensor 的计算和 variable 的计算.
t_out = torch.mean(tensor*tensor) # x^2
v_out = torch.mean(variable*variable) # x^2
print(t_out)
print(v_out) # 7.5
tensor(7.5000)
tensor(7.5000, grad_fn=<MeanBackward0>)
到目前为止, 我们看不出什么不同, 但是时刻记住, Variable 计算时, 它在背景幕布后面一步步默默地搭建着一个庞大的系统, 叫做计算图, computational graph. 这个图是用来干嘛的? 原来是将所有的计算步骤 (节点) 都连接起来, 最后进行误差反向传递的时候, 一次性将所有 variable 里面的修改幅度 (梯度) 都计算出来, 而 tensor 就没有这个能力啦.
v_out = torch.mean(variable*variable) 就是在计算图中添加的一个计算步骤, 计算误差反向传递的时候有他一份功劳, 我们就来举个例子:
v_out.backward() # 模拟 v_out 的误差反向传递
# 下面两步看不懂没关系, 只要知道 Variable 是计算图的一部分, 可以用来传递误差就好.
# v_out = 1/4 * sum(variable*variable) 这是计算图中的 v_out 计算步骤
# 针对于 v_out 的梯度就是, d(v_out)/d(variable) = 1/4*2*variable = variable/2
print("\n:",variable.grad) # 初始 Variable 的梯度
: tensor([[0.5000, 1.0000],
[1.5000, 2.0000]])
标量求导计算图
我们先声明一个变量x,这里requires_grad=True意义是否对这个变量求梯度,默认的 Fa!se:
# 建立计算图
from torch.autograd import Variable
x = Variable(torch.Tensor([2]), requires_grad = True)
print(x)
tensor([2.], requires_grad=True)
我们再声明两个变量w和b:
w = Variable(torch.Tensor([3]),requires_grad = True)
print(w)
b = Variable(torch.Tensor([4]),requires_grad = True)
print(b)
tensor([3.], requires_grad=True)
tensor([4.], requires_grad=True)
我们再写两个变量y1和y2:
y1 = w * x + b
print(y1)
y2 = w * x + b * x
print(y2)
tensor([10.], grad_fn=<AddBackward0>)
tensor([14.], grad_fn=<AddBackward0>)
我们来计算各个变量的梯度,首先是y1:
#计算梯度
y1.backward()
print(x.grad)
print(w.grad)
print(b.grad)
tensor([3.])
tensor([2.])
tensor([1.])
代码为:
# 建立计算图
from torch.autograd import Variable
x = Variable(torch.Tensor([2]), requires_grad = True)
print(x)
w = Variable(torch.Tensor([3]), requires_grad = True)
print(w)
b = Variable(torch.Tensor([4]), requires_grad = True)
print(b)
y1 = w*x+b
print(y1)
y2 = w*x+b+x
print(y2)
# 计算梯度
y1.backward()
print(x.grad)
print(w.grad)
print(b.grad)
其中:
y1 = 3 * 2 + 4 = 10,
y2 = 3 * 2 + 4 * 2 = 14,
x的梯度是3因为是3 * x,
w的梯度是2因为w * 2,
b的梯度是1因为b * 1(* 1被省略)
三,矩阵求导计算图
如
# 矩阵求导
c = Variable(torch.randn(3), requires_grad = True)
print(c)
y3 = c*2
# y3.backward(torch.Tensor([1,1,1]))
# print(c.grad)
y3.backward(torch.Tensor([1,0.1,0.01]))
print(c.grad)
tensor([1.0374, 0.7317, 0.6659], requires_grad=True)
tensor([2.0000, 0.2000, 0.0200])
可以看到,c是一个1行3列的矩阵,因为y3 = c * 2,因此如果backward()里的参数为:
torch.FloatTensor([1, 1, 1])
则就是每个分量的梯度,但是传入的是:
torch.FloatTensor([1, 0.1, 0.01])
则每个分量梯度要分别乘以1,0.1和0.01
四、Variable放到GPU上执行
Tensor一样的道理,代码如下:
# Variable放在GPU上
if torch.cuda.is_available():
d = c.cuda()
print(d)
a tensor([ 0.0588, 0.6762, -0.1626], device='cuda:0', grad_fn=<CopyBackwards>)
五, Variable形式转化其他数据形式
tensor = torch.FloatTensor([[1,2],[3,4]])
variable = Variable(tensor, requires_grad=True)
print(variable) # Variable 形式
print(variable.data) # tensor 形式
print(variable.data.numpy()) # numpy 形式
tensor([[1., 2.],
[3., 4.]], requires_grad=True)
tensor([[1., 2.],
[3., 4.]])
[[1. 2.]
[3. 4.]]