Pytorch中的register_buffer
1.register_buffer( )的使用
回顾模型保存:torch.save(model.state_dict()),model.state_dict()是一个字典,里边存着我们模型各个部分的参数。
在model中,我们需要更新其中的参数,训练结束将参数保存下来。但在某些时候,我们可能希望模型中的某些参数参数不更新(从开始到结束均保持不变),但又希望参数保存下来(model.state_dict() ),这是我们就会用到 register_buffer() 。
随着例子边看边讲
例子1:使用类成员变量(类成员变量并不会在我们的model.state_dict(),即无法保存)
成员变量(self.tensor)在前向传播中用到,希望它也能保存下来,但他不在我们的state_dict中。
class my_model(nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.conv = nn.Conv2d(1, 1, 3, 1, 1)
self.tensor = torch.randn(size=(1, 1, 5, 5)) # 成员变量
def forward(self, x):
return self.conv(x) + self.tensor
x = torch.randn(size=(1, 1, 5, 5))
model = my_model()
model(x)
print(model.state_dict())
print('..........')
print(model.tensor)
# OrderedDict([('conv.weight', tensor([[[[ 0.1797, -0.1616, 0.1784],
# [ 0.2831, -0.0466, 0.1068],
# [ 0.0733, -0.2953, -0.2349]]]])), ('conv.bias', tensor([-0.3234]))])
# ..........
# tensor([[[[-0.0058, 0.3659, 0.8884, -0.9833, 0.4962],
# [ 0.1103, 0.5936, 0.2021, -1.8994, 0.1486],
# [ 0.9335, 0.1341, 0.1928, 0.5942, 0.7708],
# [-0.8632, 1.4890, -0.3192, 1.2532, 0.8017],
# [ 0.6020, 0.0112, 0.4995, -0.7160, -1.1624]]]])
例子2:使用类成员变量(类成员变量并不会随着model.cuda()复制到gpu上)
将上一个例子中的模型复制到GPU上,但成员变量并不会随着model.cuda()复制到gpu上。torch中如果有数据不在同一个“地方”进行“运算”,程序会报错, 即self.tensor在 “ cpu ” 上,模型和 x 在 “ cuda:0 ” 上。
class my_model(nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.conv = nn.Conv2d(1, 1, 3, 1, 1)
self.tensor = torch.randn(size=(1, 1, 5, 5)) # 成员变量
def forward(self, x):
return self.conv(x) + self.tensor
x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
print(model.state_dict())
print('..........')
print(model.tensor)
# 报错!!!
# RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
例子3:使用register_buffer()
self.register_buffer(‘my_buffer’, self.tensor):my_buffer是名字,str类型;self.tensor是需要进行register登记的张量。这样我们就得到了一个新的张量,这个张量会保存在model.state_dict()中,也就可以随着模型一起通过.cuda()复制到gpu上。
class my_model(nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.conv = nn.Conv2d(1, 1, 3, 1, 1)
self.tensor = torch.randn(size=(1, 1, 5, 5))
self.register_buffer('my_buffer', self.tensor)
def forward(self, x):
return self.conv(x) + self.my_buffer # 这里不再是self.tensor
x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
print(model.state_dict())
print('..........')
print(model.tensor)
print(model.my_buffer)
# OrderedDict([('my_buffer', tensor([[[[ 0.0719, -0.5347, 0.5229, -0.5599, -0.5907],
# [-0.2743, -0.6166, -1.6723, -0.0386, 0.9706],
# [-1.0789, 0.9852, 0.1703, -0.6299, -0.5167],
# [ 0.4972, -0.9745, -0.3185, 0.3618, 0.2458],
# [ 1.5783, -0.5800, 0.1895, -0.9914, 1.1207]]]], device='cuda:0')), ('conv.weight', tensor([[[[-0.0192, 0.0500, 0.0635],
# [ 0.3025, -0.2644, 0.2325],
# [ 0.0806, 0.0457, -0.0427]]]], device='cuda:0')), ('conv.bias', tensor([-0.3074], device='cuda:0'))])
# ..........
# tensor([[[[ 0.0719, -0.5347, 0.5229, -0.5599, -0.5907],
# [-0.2743, -0.6166, -1.6723, -0.0386, 0.9706],
# [-1.0789, 0.9852, 0.1703, -0.6299, -0.5167],
# [ 0.4972, -0.9745, -0.3185, 0.3618, 0.2458],
# [ 1.5783, -0.5800, 0.1895, -0.9914, 1.1207]]]])
# tensor([[[[ 0.0719, -0.5347, 0.5229, -0.5599, -0.5907],
# [-0.2743, -0.6166, -1.6723, -0.0386, 0.9706],
# [-1.0789, 0.9852, 0.1703, -0.6299, -0.5167],
# [ 0.4972, -0.9745, -0.3185, 0.3618, 0.2458],
# [ 1.5783, -0.5800, 0.1895, -0.9914, 1.1207]]]], device='cuda:0')
总结
成员变量:不更新,但是不算是模型中的参数(model.state_dict())
通过register_buffer()登记过的张量:会自动成为模型中的参数,随着模型移动(gpu/cpu)而移动,但是不会随着梯度进行更新。
2.Parameter与Buffer
模型保存下来的参数有两种:一种是需要更新的Parameter,另一种是不需要更新的buffer。在模型中,利用backward反向传播,可以通过requires_grad来得到buffer和parameter的梯度信息,但是利用optimizer进行更新的是parameter,buffer不会更新,这也是两者最重要的区别。这两种参数都存在于model.state_dict()的OrderedDict中,也会随着模型“移动”(model.cuda())。
2.1 model.buffers()和model.named_buffers : 对模型中的buffer进行访问
与model.parameters()和model.named_parameters()相同,只是一个是对模型中的parameter访问,一个是对模型中的buffer访问。
class my_model(nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.conv = nn.Conv2d(1, 1, 3, 1, 1)
self.tensor = torch.randn(size=(1, 1, 5, 5))
self.tensor2 = torch.randn(size=(1, 1))
self.register_buffer('my_buffer', self.tensor)
self.register_buffer('my_buffer2', self.tensor2)
def forward(self, x):
return self.conv(x) + self.my_buffer
x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
for para in model.parameters():
print(para)
print('.............................................................')
for buff in model.buffers():
print(buff)
# Parameter containing:
# tensor([[[[ 0.1019, 0.3182, 0.1563],
# [-0.0207, -0.0562, 0.1807],
# [ 0.2703, -0.1186, -0.2867]]]], device='cuda:0', requires_grad=True)
# Parameter containing:
# tensor([-0.0992], device='cuda:0', requires_grad=True)
# .............................................................
# tensor([[[[ 1.3138, 1.3372, 1.6745, -0.8393, -0.1983],
# [-1.3365, 1.0321, -0.7752, 1.4250, 0.9376],
# [-0.9306, 0.1586, 0.5963, -1.0124, -0.6470],
# [ 0.6429, -1.1386, 0.8107, -0.8500, 0.4866],
# [ 0.0342, 1.5359, 0.6636, 0.2488, 0.0490]]]], device='cuda:0')
# tensor([[0.7401]], device='cuda:0')
2.2 Buffer变量可以通过backward()得到梯度信息
buffer变量和parameter变量一样,都可以通过backward()得到梯度信息,但区别是优化器optimizer更新的parameter变量,所以buffer并不会更新。
class my_model(nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.tensor = torch.randn(size=(2, 2))
self.register_buffer('my_buffer', self.tensor)
def forward(self, x):
return self.my_buffer * x
x = torch.randn(size=(3, 2, 2))
print(x.sum(0))
model = my_model()
for i in model.buffers():
i.requires_grad = True
y = model(x)
y.sum().backward()
print(model.my_buffer.grad)
# tensor([[ 0.9373, -1.0798],
# [-2.3031, 3.7299]])
# tensor([[ 0.9373, -1.0798],
# [-2.3031, 3.7299]])
2.3 Buffer变量不需要求梯度时,可通过Parameter代替
在构造模型时候,可以将某些Parameter从模型中通过“ .detach() ” 方法或直接将Parameter的requires_grad设置为False,使得此变量不求梯度,也可达到不更新的效果。
- 通过nn.Paramter()将张量设置为变量,同时设置requires_grad为False
- 这个变量也会随着模型保存,并且随着模型“移动”
- 可达到与buffer相同的效果
为什么要存在buffer:
buffer与parameter具有 “同等地位”,所以将某些不需要更新的变量“拿出来”作为buffer,可能更方便操作,可读性也更高,对Paramter的各种操作(固定网络的等)可能也不会“误伤到” buffer这种变量。buffer最重要的意义应该是需要得到梯度信息时,不会更新因为optimizer而更新,这也是parameter所不能代替的。
class my_model(nn.Module):
def __init__(self):
super(my_model, self).__init__()
self.conv = nn.Conv2d(1, 1, 3, 1, 1)
self.tensor = nn.Parameter(torch.randn(size=(1, 1, 5, 5)), requires_grad=False)
def forward(self, x):
return self.conv(x) + self.tensor
x = torch.randn(size=(1, 1, 5, 5))
x = x.to('cuda')
model = my_model().cuda()
model(x)
for para in model.named_parameters():
print(para)
# ('tensor', Parameter containing:
# tensor([[[[ 0.3341, 1.1750, -1.9723, -1.6728, -0.2374],
# [-0.6646, 0.5763, -1.5781, 0.5802, 1.3265],
# [-0.0238, 0.3929, 1.0691, 2.0344, -0.7371],
# [-1.5995, -0.0445, 0.6577, 0.5779, 0.7600],
# [-0.6772, 1.6578, -0.8476, -0.7227, -0.5070]]]], device='cuda:0'))
# ('conv.weight', Parameter containing:
# tensor([[[[-0.3241, -0.3318, 0.0154],
# [ 0.0100, 0.0003, -0.0430],
# [-0.3331, -0.2996, -0.1164]]]], device='cuda:0', requires_grad=True))
# ('conv.bias', Parameter containing:
# tensor([0.2876], device='cuda:0', requires_grad=True))
3.BN中的参数
最近发现bn中的running_mean,running_var, num_batches_tracked这三个参数是buffer类型的,这样既可以用state_dict()保存,也不会随着optimizer更新。
此外,我们要注意,state_dict()只会保存parameters和buffers类型的变量,如果我们有变量没有转成这两种类型,最后是不会被保存的!!!
class network(nn.Module):
def __init__(self):
super(network, self).__init__()
self.conv = nn.Conv2d(1, 1, 1, padding=0)
self.bn = nn.BatchNorm2d(2)
def forward(self, x):
return self.bn(self.conv(x))
net = network()
for n, a in net.named_buffers():
print(n, a)
print('.........')
for w in net.parameters():
print(w)
print('.........')
for v in net.state_dict():
print(v)
# bn.running_mean tensor([0., 0.])
# bn.running_var tensor([1., 1.])
# bn.num_batches_tracked tensor(0)
# .........
# Parameter containing:
# tensor([[[[0.1984]]]], requires_grad=True)
# Parameter containing:
# tensor([0.4412], requires_grad=True)
# Parameter containing:
# tensor([1., 1.], requires_grad=True)
# Parameter containing:
# tensor([0., 0.], requires_grad=True)
# .........
# conv.weight
# conv.bias
# bn.weight
# bn.bias
# bn.running_mean
# bn.running_var
# bn.num_batches_tracked