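conv2d is the function most commonly used to implement convolution. When calling the convolution API in TensorFlow, the typical signature (shown here in its TensorFlow 1.x form) is:
See: https://tensorflow.google.cn/api_docs/python/tf/nn/conv2d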
tf.nn.conv2d(
input,
filter,
strides,
padding,
use_cudnn_on_gpu=True,
data_format='NHWC',
dilations=[1, 1, 1, 1],
name=None
)
One of these parameters, padding, is documented as follows:
padding: A `string` from: `"SAME", "VALID"`.
The type of padding algorithm to use.
padding accepts one of two values: "SAME" or "VALID".
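To make the difference between the two modes concrete, here is a minimal sketch of how the output size along one spatial dimension can be computed. The helper name conv_output_size is hypothetical (it is not part of TensorFlow), and the formulas assume a dilation of 1.

```python
import math

def conv_output_size(in_size, kernel_size, stride, padding):
    """Output length along one spatial dimension (assumes dilation == 1).

    Hypothetical helper for illustration only; not a TensorFlow function.
    """
    if padding == "SAME":
        # SAME pads the input so the output depends only on input size and stride.
        return math.ceil(in_size / stride)
    if padding == "VALID":
        # VALID uses no padding; the kernel must fit entirely inside the input.
        return math.ceil((in_size - kernel_size + 1) / stride)
    raise ValueError("padding must be 'SAME' or 'VALID'")

# A 5x5 input with a 3x3 kernel and stride 2:
print(conv_output_size(5, 3, 2, "SAME"))   # 3
print(conv_output_size(5, 3, 2, "VALID"))  # 2
```

With a 5x5 input, a 3x3 kernel and a stride of 2, this gives a 3x3 output for "SAME" and a 2x2 output for "VALID".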
Here are two examples:
Example one: padding='SAME'
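# Written against the TensorFlow 1.x API (tf.Session, tf.random_normal).
import tensorflow as tf
# Input in NHWC layout: [batch, height, width, channels]
input = tf.Variable(tf.random_normal([1,5,5,3]))
# Filter: [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,3,7]))
# Stride of 2 along height and width, with 'SAME' padding
result = tf.nn.conv2d(input, filter, strides=[1,2,2,1],padding='SAME')
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
print(sess.run(result))
print(result.shape)
sess.close()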
Output:
[[[[ 6.88815355e-01 -1.58929396e+00 -8.13352680e+00 3.47248018e-01
-2.10637522e+00 -2.47548366e+00 -3.29180861e+00]
[ -1.50164223e+00 -2.82424307e+00 -2.40781856e+00 -2.55665493e+00
-3.89841533e+00 -6.71445191e-01 3.10867667e+00]
[ -3.39479542e+00 -1.40321875e+00 2.29996824e+00 -3.98842275e-01
7.90905952e-03 -1.71421432e+00 -5.47636747e-01]]
[[ -1.07995415e+00 -2.21969414e+00 -1.43076777e-01 2.65041399e+00
-4.38491011e+00 -4.83550358e+00 8.30997753e+00]
[ 1.35791779e+00 -1.38357902e+00 -4.50581169e+00 1.22106361e+00
-1.36877072e+00 -1.19497585e+00 -3.64005876e+00]
[ -3.07881045e+00 1.33630781e+01 -4.33032846e+00 1.98507690e+00
-1.34837186e+00 -3.44964921e-01 -5.76371312e-01]]
[[ -4.02724743e-01 -3.08082283e-01 1.51205099e+00 -2.11967897e+00
8.77675891e-01 -3.89271736e-01 1.28933489e+00]
[ 1.05681574e+00 3.83993292e+00 1.46158600e+00 5.12251711e+00
-4.37659168e+00 -5.88564873e-02 8.72927666e-01]
[ 3.13625002e+00 -2.52725768e+00 -1.89247894e+00 -2.89734745e+00
2.49475980e+00 -7.85117006e+00 4.73596001e+00]]]]
(1, 3, 3, 7)
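The 3x3 spatial size above follows from how "SAME" pads the input. The sketch below reproduces the padding arithmetic described in the TensorFlow documentation for one spatial dimension; the helper name same_padding_amounts is hypothetical and a dilation of 1 is assumed.

```python
def same_padding_amounts(in_size, kernel_size, stride):
    """Return (pad_before, pad_after) for 'SAME' padding along one dimension.

    Hypothetical helper for illustration only; assumes dilation == 1.
    """
    out_size = -(-in_size // stride)  # ceil(in_size / stride)
    # Total padding needed so that out_size windows fit over the input.
    total = max((out_size - 1) * stride + kernel_size - in_size, 0)
    pad_before = total // 2
    pad_after = total - pad_before  # any odd extra cell goes at the end
    return pad_before, pad_after

# 5-wide input, 3-wide kernel, stride 2: one cell of zero padding on each side.
print(same_padding_amounts(5, 3, 2))  # (1, 1)
```

Here the 5x5 input is padded to 7x7, and a 3x3 kernel moving with stride 2 then fits at 3 positions per dimension, which matches the (1, 3, 3, 7) shape printed above.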
Example two: padding='VALID'
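# Written against the TensorFlow 1.x API (tf.Session, tf.random_normal).
import tensorflow as tf
# Input in NHWC layout: [batch, height, width, channels]
input = tf.Variable(tf.random_normal([1,5,5,3]))
# Filter: [filter_height, filter_width, in_channels, out_channels]
filter = tf.Variable(tf.random_normal([3,3,3,7]))
# Same stride as before, but 'VALID' adds no padding
result = tf.nn.conv2d(input, filter, strides=[1,2,2,1],padding='VALID')
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
print(sess.run(result))
print(result.shape)
sess.close()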
Output:
[[[[ 3.30246162e+00 1.00174313e+01 1.02988682e+01 -3.38870287e+00
3.57620907e+00 9.25950432e+00 1.40226996e+00]
[ 2.39865661e+00 4.90117121e+00 6.27546692e+00 -7.14295626e+00
-1.87810266e+00 4.73461962e+00 -8.87438393e+00]]
[[ 5.66498578e-01 1.21167574e+01 -2.98488545e+0