How Can We Be So Dense? The Benefits of Using Highly Sparse Representations
A paper reproduction.
Please read the paper before reading this post. This post simplifies the code released with the paper, and reproduces and tests the paper's results.
The paper describes how networks that use sparse representations are robust to noise.
Concretely, the paper trains a network with one convolutional layer and two fully connected layers on plain MNIST, then evaluates it on test sets with added noise. When the convolutional and fully connected layers are sparsely connected, the sparse network is far more robust to noise than the original dense network.
K_winner
K_winner is an activation function that replaces ReLU. The basic idea is the same as ReLU, but instead of letting every input above a fixed threshold (usually 0) pass through, it sorts the inputs and passes only the largest k. Put differently, it sets a dynamic threshold such that exactly k inputs exceed it and may fire. This controls how sparse the layer's output is; in the example below, only the top 100 of the convolutional layer's 4320 activations are kept.
Backpropagation works the same way as for ReLU: the gradient passes through unchanged for the k winning units and is zero for all others.
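Before the full autograd implementation, a minimal sketch of the forward selection alone (the tensor x and k=3 are made up for illustration):

import torch

x = torch.tensor([0.1, -0.5, 2.0, 0.7, 1.3, -0.2, 0.4, 0.9])
k = 3

# Keep only the k largest activations; everything else becomes 0.
topk_vals, topk_idx = x.topk(k)
out = torch.zeros_like(x)
out[topk_idx] = x[topk_idx]

print(out)              # tensor([0.0, 0.0, 2.0, 0.0, 1.3, 0.0, 0.0, 0.9])
# The implicit "dynamic threshold" is the k-th largest value:
print(topk_vals.min())  # tensor(0.9000)

The full implementation, including duty-cycle boosting, follows.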
import torch

class k_winners(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dutyCycles, k, boostStrength):
        if boostStrength > 0.0:
            targetDensity = float(k) / x.size(1)
            boostFactors = torch.exp((targetDensity - dutyCycles) * boostStrength)
            boosted = x.detach() * boostFactors
        else:
            boosted = x.detach()
        # Take the boosted version of the input x, find the top k winners.
        # Compute an output that contains the values of x corresponding to the
        # top k boosted values; everything else is zero.
        res = torch.zeros_like(x)
        topk, indices = boosted.topk(k, sorted=False)
        for i in range(x.shape[0]):
            res[i, indices[i]] = x[i, indices[i]]
        ctx.save_for_backward(indices)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass, the gradient passes through unchanged for the
        winning units and is zero for the others.
        """
        indices, = ctx.saved_tensors
        grad_x = torch.zeros_like(grad_output, requires_grad=False)
        # Probably a better way to do it, but this is not terrible as it only
        # loops over the batch size.
        for i in range(grad_output.size(0)):
            grad_x[i, indices[i]] = grad_output[i, indices[i]]
        return grad_x, None, None, None
class k_winners2d(torch.autograd.Function):
    """
    A k-winner take all autograd function for CNN 2D inputs (batch, channel, H, W).

    .. seealso::
        Function :class:`k_winners`
    """

    @staticmethod
    def forward(ctx, x, dutyCycles, k, boostStrength):
        batchSize = x.shape[0]
        if boostStrength > 0.0:
            targetDensity = float(k) / (x.shape[1] * x.shape[2] * x.shape[3])
            boostFactors = torch.exp((targetDensity - dutyCycles) * boostStrength)
            boosted = x.detach() * boostFactors
        else:
            boosted = x.detach()
        # Take the boosted version of the input x, find the top k winners.
        # Compute an output that only contains the values of x corresponding to
        # the top k boosted values. The rest of the elements are 0.
        boosted = boosted.reshape((batchSize, -1))
        xr = x.reshape((batchSize, -1))
        res = torch.zeros_like(boosted)
        topk, indices = boosted.topk(k, dim=1, sorted=False)
        res.scatter_(1, indices, xr.gather(1, indices))
        res = res.reshape(x.shape)
        ctx.save_for_backward(indices)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass, the gradient passes through unchanged for the
        winning units and is zero for the others.
        """
        batchSize = grad_output.shape[0]
        indices, = ctx.saved_tensors
        g = grad_output.reshape((batchSize, -1))
        grad_x = torch.zeros_like(g, requires_grad=False)
        grad_x.scatter_(1, indices, g.gather(1, indices))
        grad_x = grad_x.reshape(grad_output.shape)
        return grad_x, None, None, None
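A quick sanity check (the shapes and k are arbitrary; with boostStrength=0 boosting is disabled, so any placeholder works for dutyCycles):

import torch

x = torch.randn(2, 30, 12, 12, requires_grad=True)  # (batch, channel, H, W)
out = k_winners2d.apply(x, 0, 100, 0)  # dutyCycles=0 (unused), k=100, boostStrength=0
print((out != 0).sum(dim=(1, 2, 3)))     # at most 100 active units per sample
out.sum().backward()
print((x.grad != 0).sum(dim=(1, 2, 3)))  # gradient flows only through the winners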
Sparse convolutional and fully connected layers
SparseWeightNet
The method is simple: to reduce the number of connections in the network, we disable a fraction of the connections between input and output, here 50%. Concretely, the weights of the disabled connections are reset to zero on every forward pass during training.
This is fundamentally different from dropout: dropout cuts a random set of connections, different on every pass. SparseWeightNet is a truly sparse network; the set of disabled input-output connections is chosen once and stays the same from start to finish.
The paper and our test results show that the sparse convolutional layer barely affects performance.
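To make the contrast with dropout concrete, a minimal sketch (the mask, layer sizes, and 50% sparsity are illustrative, not the paper's exact code):

import torch
import torch.nn as nn

layer = nn.Linear(8, 4)

# Fixed sparsity: choose a mask once and keep it for the whole training run.
mask = (torch.rand_like(layer.weight) > 0.5).float()

def sparse_forward(x):
    layer.weight.data *= mask  # the same connections are always zero
    return layer(x)

y = sparse_forward(torch.randn(2, 8))

# Dropout, by contrast, zeroes a different random set on every pass:
drop = nn.Dropout(p=0.5)

The actual implementation below wraps an existing module and re-zeroes the selected weights before each training-mode forward pass.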
import abc
import math

import numpy as np
import torch
import torch.nn as nn

def rezeroWeights(m):
    if isinstance(m, SparseWeightsBase):
        if m.training:
            m.rezeroWeights()

def normalizeSparseWeights(m):
    """
    Kaiming-style initialization, with the fan-in adjusted for sparsity:
    only weightSparsity * inputSize weights per unit are non-zero, so the
    initialization scale is computed from that effective fan-in.
    """
    if isinstance(m, SparseWeightsBase):
        _, inputSize = m.module.weight.shape
        fan = int(inputSize * m.weightSparsity)
        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
        std = gain / math.sqrt(fan)
        # Calculate uniform bounds from the standard deviation
        bound = math.sqrt(3.0) * std
        nn.init.uniform_(m.module.weight, -bound, bound)
        if m.module.bias is not None:
            bound = 1.0 / math.sqrt(fan)
            nn.init.uniform_(m.module.bias, -bound, bound)

# An abstract base class, so the sparse-weight wrapper can be extended from
# linear layers to more kinds of layers.
class SparseWeightsBase(nn.Module):
    __metaclass__ = abc.ABCMeta

    def __init__(self, module, weightSparsity):
        super(SparseWeightsBase, self).__init__()
        assert 0 < weightSparsity < 1
        self.module = module
        self.weightSparsity = weightSparsity
        self.register_buffer("zeroWts", self.computeIndices())
        self.rezeroWeights()

    def forward(self, x):
        if self.training:
            self.rezeroWeights()
        return self.module.forward(x)

    @abc.abstractmethod
    def computeIndices(self):
        """
        For each unit, decide which weights are going to be zero
        :return: tensor indices for all zeroed weights. See :meth:`rezeroWeights`
        """
        raise NotImplementedError

    @abc.abstractmethod
    def rezeroWeights(self):
        """
        Set the previously selected weights to zero. See :meth:`computeIndices`
        """
        raise NotImplementedError

class SparseWeights(SparseWeightsBase):
    def __init__(self, module, weightSparsity):
        """
        Usage:
          model = nn.Linear(784, 10)
          model = SparseWeights(model, 0.4)
        """
        super(SparseWeights, self).__init__(module, weightSparsity)

    def computeIndices(self):
        # For each unit, decide which weights are going to be zero
        outputSize, inputSize = self.module.weight.shape
        numZeros = int(round((1.0 - self.weightSparsity) * inputSize))
        outputIndices = np.arange(outputSize)
        # For each output unit, take a random permutation of the inputSize
        # input indices and keep the first numZeros: these are the weights
        # that will be zeroed.
        inputIndices = np.array([np.random.permutation(inputSize)[:numZeros]
                                 for _ in outputIndices], dtype=np.int64)
        # Create tensor indices for all weights to be zeroed:
        # (output index, input index) pairs.
        zeroIndices = np.empty((outputSize, numZeros, 2), dtype=np.int64)
        zeroIndices[:, :, 0] = outputIndices[:, None]
        zeroIndices[:, :, 1] = inputIndices
        # Exactly outputSize * numZeros positions are filled.
        zeroIndices = zeroIndices.reshape(-1, 2)
        return torch.from_numpy(zeroIndices.transpose())

    def rezeroWeights(self):
        zeroIdx = (self.zeroWts[0], self.zeroWts[1])
        self.module.weight.data[zeroIdx] = 0.0

class SparseWeights2d(SparseWeightsBase):
    def __init__(self, module, weightSparsity):
        super(SparseWeights2d, self).__init__(module, weightSparsity)

    def computeIndices(self):
        # For each output channel, decide which weights are going to be zero.
        # The effective "input size" of a channel is inChannels * kernel area.
        inChannels = self.module.in_channels
        outChannels = self.module.out_channels
        kernelSize = self.module.kernel_size
        inputSize = inChannels * kernelSize[0] * kernelSize[1]
        numZeros = int(round((1.0 - self.weightSparsity) * inputSize))
        outputIndices = np.arange(outChannels)
        inputIndices = np.array([np.random.permutation(inputSize)[:numZeros]
                                 for _ in outputIndices], dtype=np.int64)
        # Create tensor indices for all weights to be zeroed
        zeroIndices = np.empty((outChannels, numZeros, 2), dtype=np.int64)
        zeroIndices[:, :, 0] = outputIndices[:, None]
        zeroIndices[:, :, 1] = inputIndices
        zeroIndices = zeroIndices.reshape(-1, 2)
        return torch.from_numpy(zeroIndices.transpose())

    def rezeroWeights(self):
        zeroIdx = (self.zeroWts[0], self.zeroWts[1])
        self.module.weight.data.view(self.module.out_channels, -1)[zeroIdx] = 0.0
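A quick check that the wrapper behaves as described (the layer sizes are arbitrary):

import torch
import torch.nn as nn

fc = SparseWeights(nn.Linear(784, 300), weightSparsity=0.5)
w = fc.module.weight.data
print((w == 0).float().mean())          # ~0.5: half of each unit's weights are zero
before = fc.zeroWts.clone()
fc.rezeroWeights()
assert torch.equal(before, fc.zeroWts)  # the zeroed positions never change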
Adding noise to the data
from __future__ import print_function

import numpy as np

class RandomNoise(object):
    """
    Set a random fraction (noiselevel) of the image's pixels to whiteValue.
    The default whiteValue is mean + 2 * std of the normalized MNIST images.
    """

    def __init__(self, noiselevel=0.0, whiteValue=0.1307 + 2 * 0.3081):
        self.noiseLevel = noiselevel
        self.whiteValue = whiteValue
        self.iteration = 0

    def __call__(self, image):
        self.iteration += 1
        a = image.view(-1)
        numNoiseBits = int(a.shape[0] * self.noiseLevel)
        # Pick numNoiseBits random pixel positions and overwrite them in place.
        noise = np.random.permutation(a.shape[0])[0:numNoiseBits]
        a[noise] = self.whiteValue
        return image
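For example, applied to a single image tensor (the shape is MNIST's):

import torch

img = torch.zeros(1, 28, 28)
noisy = RandomNoise(noiselevel=0.2)(img)
print((noisy == 0.1307 + 2 * 0.3081).float().mean())  # ~0.2 of the pixels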
We test with the following code.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torch_mnist.k_winners import k_winners2d
from torch_mnist import sparse_weights
from torch_mnist.image_transforms import RandomNoise

batch_size = 64
NOISE_VALUES = ["0.0", "0.05", "0.1", "0.15", "0.2", "0.25", "0.3", "0.35",
                "0.4", "0.45", "0.5", "0.55", "0.6", "0.65", "0.7", "0.75", "0.8"]

transform_ = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data/',
                               train=True,
                               transform=transform_,
                               download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

class SparseNet(nn.Module):
    def __init__(self, sparseCNN=False, k_winner=False, sparseLinear=False):
        super(SparseNet, self).__init__()
        self.sparseCNN = sparseCNN
        self.k_winner = k_winner
        self.sparseLinear = sparseLinear
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=30, kernel_size=5)
        if self.sparseCNN:
            self.conv1 = sparse_weights.SparseWeights2d(self.conv1, weightSparsity=0.5)
        self.mp = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(4320, 300)  # 30 channels * 12 * 12 after pooling
        if self.sparseLinear:
            self.fc1 = sparse_weights.SparseWeights(self.fc1, weightSparsity=0.5)
        self.fc2 = nn.Linear(300, 10)

    def forward(self, x):
        in_size = x.size(0)
        x = self.mp(self.conv1(x))
        if self.k_winner:
            # Arguments are (x, dutyCycles, k, boostStrength). With
            # boostStrength=0 the dutyCycles argument is unused, so the 1000
            # is just a placeholder; k=100 of the 4320 conv activations are kept.
            x = k_winners2d.apply(x, 1000, 100, 0)
        else:
            x = F.relu(x)
        x = x.view(in_size, -1)  # flatten the tensor (reshape)
        x = self.fc1(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # 64*10

model = SparseNet(sparseCNN=True, k_winner=True, sparseLinear=True)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train(epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = F.nll_loss(output, target)
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        optimizer.zero_grad()  # zero the gradients of all parameters
        loss.backward()        # backpropagate to compute the gradients
        optimizer.step()       # gradient-descent update of the parameters

def test(noise_idx):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    noise = float(NOISE_VALUES[noise_idx])
    transform.transforms.append(RandomNoise(noise, whiteValue=0.1307 + 2 * 0.3081))
    test_dataset = datasets.MNIST(root='./data/',
                                  train=False,
                                  transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        with torch.no_grad():
            output = model(data)
        # sum up batch loss
        test_loss += F.cross_entropy(output, target, reduction='sum').item()
        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    # The test set has 10,000 images, so correct / 100 is the accuracy in percent.
    print(correct.item() / 100)

for epoch in range(1):
    train(epoch)
for i in range(17):
    test(i)
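The dense baseline in the table below comes from running the same script with all three flags off. A hedged sketch of a comparison driver (it reuses SparseNet, train, and test from above and rebinds the global model and optimizer that they read):

for name, flags in [("denseNet", dict()),
                    ("sparseNet", dict(sparseCNN=True, k_winner=True, sparseLinear=True))]:
    model = SparseNet(**flags)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    train(0)
    print(name)
    for i in range(17):
        test(i)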
We added between 0% and 80% noise to the test set and obtained the following results.
The sparse network shows an astonishing tolerance to noise.
Noise level | denseNet accuracy (%) | sparseNet accuracy (%) |
--- | --- | --- |
0 | 96.84 | 95.38 |
0.05 | 96.55 | 95.13 |
0.1 | 96.2 | 94.55 |
0.15 | 95.56 | 94.02 |
0.2 | 94.9 | 93.84 |
0.25 | 92.64 | 92.72 |
0.3 | 89.22 | 92.25 |
0.35 | 83.86 | 91.06 |
0.4 | 77.36 | 90.04 |
0.45 | 69.36 | 88.89 |
0.5 | 59.37 | 87.58 |
0.55 | 48.66 | 85.35 |
0.6 | 38.23 | 82.98 |
0.65 | 28.23 | 80.12 |
0.7 | 21.49 | 75.18 |
0.75 | 15.63 | 69.1 |
0.8 | 12.19 | 60.07 |