了解Variable
顾名思义, V a r i a b l e Variable Variable就是变量的意思,实际上也就是可以变化的量,区别于 i n t int int变量,其是一种可以变化的量,这正好就符合了反向传播,参数更新的属性。
具体来说,在
P
y
T
o
r
c
h
PyTorch
PyTorch中的
V
a
r
i
a
b
l
e
Variable
Variable就是一个存放会变化量的地理位置,里面的值会不停的发生变化,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化,那谁是里面的鸡蛋呢?自然就是
P
y
T
o
r
c
h
PyTorch
PyTorch里面的张量,(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式),如果用
V
a
r
a
i
b
l
e
Varaible
Varaible计算的话,那返回的也是一个同类型的
V
a
r
i
a
b
l
e
Variable
Variable。
用一个例子说明
V
a
r
i
a
b
l
e
Variable
Variable的定义:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True)
print(tensor)
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable)
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
注意 t e n s o r tensor tensor不能反向传播, V a r i a b l e Variable Variable可以反向传播。
Variable求梯度
V a r i a b l e Variable Variable计算时,其会逐渐的生成计算图,这个图就是将所有的计算节点都连接起来,最后进行误差反向传播传递的时候,一次性的将所有 V a r i a b l e Variable Variable里面的梯度都计算出来,而 t e n s o r tensor tensor就没有这个能力。
v_out.backward() # 模拟 v_out 的误差反向传递
print(variable.grad) # 初始 Variable 的梯度
'''
0.5000 1.0000
1.5000 2.0000
'''
获取Variable里面的数据
直接print(Variable) 只**会输出Variable形式的数据,**在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。
print(variable) # Variable 形式
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data) # 将variable形式转为tensor 形式
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data.numpy()) # numpy 形式
"""
[[ 1. 2.]
[ 3. 4.]]
"""
扩展
在
P
T
y
t
o
r
c
h
PTytorch
PTytorch计算图的特点总结如下:
a
u
t
o
g
r
a
d
autograd
autograd根据用户对
V
a
r
i
a
b
l
e
Variable
Variable的操作来构建计算图。
requires_grad
V a r i a b l e Variable Variable是默认不需要被求导的,即即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。
Volatile
variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。
retain_graph
多次反向传播(多层监督)时,梯度时累加的,一般来说,单次反向传播后,计算图会 f r e e free free掉,也就是反向传播的中间缓存会被清空,(这就是动态度的特点) 为进行多次反向传播需指定 r e t a i n g r a p h = T r u e retain_graph=True retaingraph=True来保存这些缓存。
- backward(): 反向传播,求解 V a r i a b l e Variable Variable的梯度,放在中间缓存中。