Tensorflow--池化操作的梯度

Tensorflow–池化操作的梯度

池化操作的梯度分两部分介绍,第一部分介绍平均值池化的梯度计算,第二部分介绍最大值池化的梯度计算

一.平均值池化的梯度

利用计算梯度的函数gradients实现上述示例,具体代码如下:

import tensorflow as tf
import numpy as np

# x是1个3行3列1深度的张量
x=tf.placeholder(tf.float32,(1,3,3,1))

# 2x2的掩码,步长是(1,1,1,1)的valid平均值池化操作
sigma=tf.nn.avg_pool(x,(1,2,2,1),(1,1,1,1),'VALID')

# 构造一个函数F:池化结果的和
F=tf.reduce_sum(sigma)

session=tf.Session()

xvalue=np.random.randn(1,3,3,1)
grad=tf.gradients(F,[sigma,x])
results=session.run(grad,{x:xvalue})

print("---针对sigma的梯度---:")
print(results[0])
print("---针对x的梯度---:")
print(results[1])
---针对sigma的梯度---:
[[[[1.]
   [1.]]

  [[1.]
   [1.]]]]
---针对x的梯度---:
[[[[0.25]
   [0.5 ]
   [0.25]]

  [[0.5 ]
   [1.  ]
   [0.5 ]]

  [[0.25]
   [0.5 ]
   [0.25]]]]

二.最大值池化的梯度

import tensorflow as tf

# 初始化x的值
x=tf.Variable(tf.constant([
                           [
                           [[8],[2],[9],[3]],
                           [[4],[6],[7],[10]],
                           [[20],[13],[1],[5]],
                           [[12],[18],[19],[14]]
                           ]
                           ],tf.float32),dtype=tf.float32)

# 2x2的掩码,步长为2x2的最大值池化操作
x_maxPool=tf.nn.max_pool(x,(1,2,2,1),(1,2,2,1),'VALID')

# 对以上最大值池化结果计算其平方和
F=tf.reduce_sum(tf.square(x_maxPool))

session=tf.Session()
session.run(tf.global_variables_initializer())

opti=tf.train.GradientDescentOptimizer(0.5).minimize(F)

# 打印前2次结果
for i in range(2):
    session.run(opti)
    print(session.run(x))
[[[[ 0.]
   [ 2.]
   [ 9.]
   [ 3.]]

  [[ 4.]
   [ 6.]
   [ 7.]
   [ 0.]]

  [[ 0.]
   [13.]
   [ 1.]
   [ 5.]]

  [[12.]
   [18.]
   [ 0.]
   [14.]]]]
[[[[ 0.]
   [ 2.]
   [ 0.]
   [ 3.]]

  [[ 4.]
   [ 0.]
   [ 7.]
   [ 0.]]

  [[ 0.]
   [13.]
   [ 1.]
   [ 5.]]

  [[12.]
   [ 0.]
   [ 0.]
   [ 0.]]]]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值