《Self-Attention Generative Adversarial Networks》阅读笔记+tensorflow实现
前言
这篇论篇幅比较短,核心就在于一个注意力机制,本篇文章主要是教大家如何利用tensorflow使用self-attention
Paper:airxiv地址
Code:pytorch版本、tensorflow版本
一、自注意力机制是什么?
其实下篇这个图并不是很需要看懂,网上也有其他人解释的比较好的,这里我的理解是大家只需要知道自注意力机制的本质思想就是想要整体数据集样本的大局观对图片中局部的内容进行优化。
二、使用步骤
1.在tensorflow代码版本中提供了attention和google_attention,源代码中使用了google_attention,我在采用了google_attention使用。
代码如下(示例),源代码使用了类的方式,这里为了方便使用,我改成了函数体的方式:
def google_attention(x, channels, scope='attention'):
with tf.variable_scope(scope):
batch_size, height, width, num_channels = x.get_shape().as_list()
f = conv(x, channels // 8, kernel=1, stride=1, sn=True, scope='f_conv') # [bs, h, w, c']
f = max_pooling(f)
g = conv(x, channels // 8, kernel=1, stride=1, sn=True, scope='g_conv') # [bs, h, w, c']
h = conv(x, channels // 2, kernel=1, stride=1, sn=True, scope='h_conv') # [bs, h, w, c]
h = max_pooling(h)
# N = h * w
s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) # # [bs, N, N]
beta = tf.nn.softmax(s) # attention map
o = tf.matmul(beta, hw_flatten(h)) # [bs, N, C]
gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
o = tf.reshape(o, shape=[batch_size, height, width, num_channels // 2]) # [bs, h, w, C]
o = conv(o, channels, kernel=1, stride=1, sn=True, scope='attn_conv')
x = gamma * o + x
return x
2.在这里的conv、max_pooling可能会报错,大家可以改成自己在tensorflow中的函数,我是尊重源码,复制了ops.py文件
在神经网络中可以直接通过如下函数通用即可,在不同的任务中要可以在神经网络的中间或者后面放置,我在利用生成对抗网络做风格迁移时我放在了残差块之后。参数分别是输入图片,输入图片的通道数(即上一层输出的图片和输出的图片的通道数):
x=ops.google_attention(input=?, channels=?, scope='self_attention')
总结
欢迎大家提出质疑,相互学习,相互进步,如有转发请注明出处。