目录
Modulating early visual processing by language
- paper: Modulating early visual processing by language
- code: https://github.com/ap229997/Conditional-Batch-Norm
Conditional Batch Normalization (CBN) 的概念就来源于这篇论文
Conditional Batch Normalization (CBN)
y
=
x
−
E
[
x
]
Var
[
x
]
+
ϵ
⋅
γ
p
r
e
d
+
β
p
r
e
d
y=\frac{x-\mathbb E[x]}{\sqrt{\text{Var}[x]+\epsilon}}\cdot\gamma_{pred}+\beta_{pred}
y=Var[x]+ϵx−E[x]⋅γpred+βpred其中,
γ
p
r
e
d
\gamma_{pred}
γpred (scale) 和
β
p
r
e
d
\beta_{pred}
βpred (bias) 不再是直接由损失函数反向传播学习得到,而是把 feature 输入一个 MLP,前向传播得到的网络输出。由于
γ
p
r
e
d
\gamma_{pred}
γpred 和
β
p
r
e
d
\beta_{pred}
βpred 依赖于输入的 feature,因此就称为 Conditional Batch Normalization
Modulating early visual processing by language
- 这篇文章改进了一个基于图片的问答系统 (VQA: Visual Question Answering)。系统的输入为一张图片和一个针对图片的问题,系统输出问题的答案,如下图所示:
这类系统通常是这样设计的:一个预训练的图像识别网络,例如 ResNet,用于提取图片特征;一个 sequential 模型,例如 LSTM、GRU 等,用于提取句子的特征,并根据句子预测应该关注图片的什么位置(attention);将语言特征、由 attention 加权过后的图片特征结合起来,共同输入一个网络,最终输出问题的答案。如上图左侧所示,LSTM 提取的特征只在 ResNet 的顶层才和图片特征结合起来,因为通常意义上讲,神经网络的底层提取的是基础的几何特征,顶层是有具体含义的语义特征,因此,应该把语言模型提取的句子特征在网络顶层和图片特征结合
- 然而作者认为,底层的图片特征也应该结合语言特征。理由是,神经科学证明:语言会帮助图片识别。例如,如果事先告诉一个人关于图片的内容,然后再让他看图片,那么这个人识别图片的速度会大大加快。因此,作者首创了将图片底层信息和语言信息结合的模型,如上图右侧所示。首先,ResNet 是预训练的网络,用于提取图片特征,因此不能轻易修改里面 filter 的参数,而其中的 BN 层有两组参数 scale 和 bias,用于对 feature map 施加缩放和偏置操作。这两个参数量不大,而且从含义上讲可以解释为:强调 feature map 的某部分 channel,忽略另外一些 channel。柿子捡软的捏,作者决定通过修改 scale 和 bias 的方式,达到有针对性地提取图片部分信息的目的,而修改的方式就是用 LSTM 提取的句子特征。例如上图,输入的句子问:伞上下颠倒了吗?LSTM 很大概率会提取出关键词:伞,把这个关键词的特征作为条件,输入到 MLP 中,输出新的权重 bias 和 scale,通过训练,这些权重最后将会有针对性地强调图片特征中与伞有关的 channel,而忽略与伞无关的 channel。而由于 ResNet 是预训练网络,即便是里面的 BN 层的参数,也是轻易不能动的。因此,作者没有直接用 MLP 的输出作为 BN 层新的 scale 和 bias,而是作为一个小的增量,加在原来的参数上:
{ γ n e w = γ + Δ γ β n e w = β + Δ β \left\{\begin{array}{l} \gamma_{n e w}=\gamma+\Delta \gamma \\ \beta_{n e w}=\beta+\Delta \beta \end{array}\right. {γnew=γ+Δγβnew=β+Δβ这个想法用最小的代价 (只修改了 BN 层参数),在图像的底层 feature 中结合了自然语言信息,取得了很好的表现
Categorical Conditional Batch Normalization
Motivation
- 在 conditional generative model 里面,存在一个隐隐让人不安的问题:一个 batch 里面不同类别的训练数据,放在一起做 Batch Normalization 不太妥当。因为不同类别的数据理应对应不同的均值和方差,其归一化、放缩、偏置也应该不同
- 针对这个问题,一个解决方案是不再考虑整个 batch 的统计特征,各个图像只在自己的 feature map 内部归一化,例如采用 Instance Normalization 和 Layer Normalization 来代替 BN。但是这些替代品的表现都不如 BN 稳定,接受程度不如 BN 高。这时就可以使用 conditional BN: 图片的类别信息也可以作为 condition 来预测 BN 层的参数。cGANs With Projection Discriminator 和 Self-Attention GANs 都借鉴了 CBN 里面的 condition 的思想,稍加修改,用在了自己的 conditional GAN 模型中
Categorical Conditional Batch Normalization
- Modulating early visual processing by language 一文中,由于使用了预训练的 ResNet,不敢对预训练网络 BN 层的参数做大修改,因此 MLP 的输出为 BN 层参数的增量,而不是直接输出新的 BN 层参数。conditional GANs 没有用到预训练网络,因此没有了历史包袱,直接用图片的 categorical 信息,预测新的 scale 和 bias。
- 接下来我们将研究其具体的实现,代码来自 Github
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import
class ConditionalBatchNorm2d(nn.BatchNorm2d):
"""Conditional Batch Normalization"""
def __init__(self, num_features, eps=1e-05, momentum=0.1,
affine=False, track_running_stats=True):
super(ConditionalBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats
)
# forward 方法中增加了 weight 和 bias 参数
def forward(self, input, weight, bias, **kwargs):
self._check_input_dim(input)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
output = F.batch_norm(input, self.running_mean, self.running_var,
self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
# 主要是在 nn.BatchNorm2d 的基础上增加了下面几行
if weight.dim() == 1:
weight = weight.unsqueeze(0)
if bias.dim() == 1:
bias = bias.unsqueeze(0)
size = output.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
return weight * output + bias
class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):
def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
affine=False, track_running_stats=True):
super(CategoricalConditionalBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats
)
# nn.Embedding 层负责把图片的 label 转换成 dense 向量
self.weights = nn.Embedding(num_classes, num_features)
self.biases = nn.Embedding(num_classes, num_features)
# 初始化 self.weights 和 self.bias,分别把它们初始化为全 1 和全 0
self._initialize()
def _initialize(self):
init.ones_(self.weights.weight.data)
init.zeros_(self.biases.weight.data)
def forward(self, input, c, **kwargs):
weight = self.weights(c)
bias = self.biases(c)
return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)