While reading ssim.py in the DSSINet code today, I came across this usage:
import torch

class NORMMSSSIM(torch.nn.Module):
    def __init__(self, sigma=1.0, levels=5, size_average=True, channel=1):
        super(NORMMSSSIM, self).__init__()
        self.sigma = sigma
        self.window_size = 5
        self.levels = levels
        self.size_average = size_average
        self.channel = channel
        # create_window() is a helper defined elsewhere in ssim.py (not shown here)
        self.register_buffer('window', create_window(self.window_size, self.channel, self.sigma))
        # Per-scale weights for the 5 MS-SSIM levels
        self.register_buffer('weights', torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))
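To see concretely what those two register_buffer calls change, here is a minimal sketch using a hypothetical `Demo` module (not part of DSSINet; the names `Demo`, `scale`, and `plain` are made up for illustration, and only the `weights` values are copied from NORMMSSSIM). It compares a learnable parameter, a registered buffer, and a plain tensor attribute.

```python
import torch


class Demo(torch.nn.Module):
    """Hypothetical toy module mirroring the register_buffer pattern in NORMMSSSIM."""

    def __init__(self):
        super().__init__()
        # Learnable parameter: listed by parameters(), updated by an optimizer.
        self.scale = torch.nn.Parameter(torch.ones(1))
        # Buffer: not a parameter, but saved in state_dict and converted by .to()/.double()/.cuda().
        self.register_buffer('weights', torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))
        # Plain tensor attribute: ignored by state_dict and by .to()/.double()/.cuda().
        self.plain = torch.zeros(3)


m = Demo()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(param_names)   # ['scale']
print(buffer_names)  # ['weights']
print(state_keys)    # ['scale', 'weights']

# Converting the module converts its parameters and buffers, but not plain attributes.
m.double()
print(m.weights.dtype)  # torch.float64
print(m.plain.dtype)    # torch.float32
```

This is why NORMMSSSIM stores `window` and `weights` as buffers: they are fixed constants that should not be trained, yet they must follow the module to the GPU (or to a new dtype) and be saved and loaded with the model.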
So what is register_buffer() for? The official docstring in nn.modules.module.py says:
Adds a persistent buffer to the module.
This is typically used to register a buffer that should not to be
considered a model parameter. For example, BatchNorm's ``running_mean``