Customizing the gradient of sign with @tf.custom_gradient

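Copyright notice: This is an original article by the blogger, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Article link: https://blog.csdn.net/qq965194745/article/details/88719838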

https://blog.csdn.net/LoseInVain/article/details/83108001
https://github.com/tensorflow/tensorflow/blob/7dd20b844ced19610f8fa67be61d93948563ac43/tensorflow/python/ops/custom_gradient.py

Input

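# Uses the TensorFlow 1.x graph-mode API (tf.Session, tf.gradients)
import numpy as np  # needed for np.linspace below
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline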
@tf.custom_gradient
def DoublySign(x):
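    def grad(dy):
        '''
        dy is the upstream gradient, i.e. the gradient arriving from the following layer during backpropagation
        '''
        cond = (x >= -1) & (x <= 1)
        zeros = tf.zeros_like(dy)
        return tf.where(cond, dy, zeros)
    # grad computes the gradient by hand: for x in [-1, 1] the gradient of sign is treated as 1, so by the chain rule the result is dy * 1 = dy
    # Everywhere else the gradient is dy * 0 = 0
    # On the difference between tf.where and tf.cond: https://blog.csdn.net/xiadimichen14908/article/details/83592282
    return tf.sign(x), grad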


x = tf.constant(np.linspace(-2,2,100))
y = DoublySign(x)
grad = tf.gradients(y,x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(x.eval())
    print(y.eval())
    print(sess.run(grad))
    plt.figure(figsize=(10,5))
    plt.subplot(1,2,1)
    ax = plt.gca()         # get the current axes object
    plt.xlabel('x')
    plt.ylabel('sign(x)')                      # set the axis labels
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')         # make the right and top spines colorless, which effectively hides them

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')        # use the bottom spine as the x axis and the left spine as the y axis

    ax.spines['bottom'].set_position(('data', 0))   # 'data' positioning: anchor the bottom spine (the x axis) at y = 0
    ax.spines['left'].set_position(('data', 0))     # likewise, anchor the left spine (the y axis) at x = 0
    plt.plot(x.eval(),y.eval())
    plt.title("sign")
    plt.subplot(1,2,2)
    plt.xlabel('x')
    plt.ylabel('d(sign(x))')                   # set the axis labels
    plt.plot(x.eval(),sess.run(grad)[0])
    plt.title("sign grad")

Output

[-2.         -1.95959596 -1.91919192 -1.87878788 -1.83838384 -1.7979798
 -1.75757576 -1.71717172 -1.67676768 -1.63636364 -1.5959596  -1.55555556
 -1.51515152 -1.47474747 -1.43434343 -1.39393939 -1.35353535 -1.31313131
 -1.27272727 -1.23232323 -1.19191919 -1.15151515 -1.11111111 -1.07070707
 -1.03030303 -0.98989899 -0.94949495 -0.90909091 -0.86868687 -0.82828283
 -0.78787879 -0.74747475 -0.70707071 -0.66666667 -0.62626263 -0.58585859
 -0.54545455 -0.50505051 -0.46464646 -0.42424242 -0.38383838 -0.34343434
 -0.3030303  -0.26262626 -0.22222222 -0.18181818 -0.14141414 -0.1010101
 -0.06060606 -0.02020202  0.02020202  0.06060606  0.1010101   0.14141414
  0.18181818  0.22222222  0.26262626  0.3030303   0.34343434  0.38383838
  0.42424242  0.46464646  0.50505051  0.54545455  0.58585859  0.62626263
  0.66666667  0.70707071  0.74747475  0.78787879  0.82828283  0.86868687
  0.90909091  0.94949495  0.98989899  1.03030303  1.07070707  1.11111111
  1.15151515  1.19191919  1.23232323  1.27272727  1.31313131  1.35353535
  1.39393939  1.43434343  1.47474747  1.51515152  1.55555556  1.5959596
  1.63636364  1.67676768  1.71717172  1.75757576  1.7979798   1.83838384
  1.87878788  1.91919192  1.95959596  2.        ]
[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]
[array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]
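
The printed gradient above (25 zeros, 50 ones, 25 zeros) can be reproduced without TensorFlow. The following is a minimal pure-Python sketch of the same forward and backward rule, not the library's implementation. The helper names `linspace`, `sign` and `clipped_ste_grad` are hypothetical and exist only for this illustration.

```python
def linspace(start, stop, num):
    """Evenly spaced points with both endpoints (stand-in for np.linspace)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def sign(v):
    """Forward pass: returns -1, 0 or 1, mirroring tf.sign."""
    return (v > 0) - (v < 0)


def clipped_ste_grad(x, dy):
    """Backward pass: let dy through where -1 <= x <= 1, otherwise return 0."""
    return [g if -1 <= v <= 1 else 0.0 for v, g in zip(x, dy)]


xs = linspace(-2.0, 2.0, 100)
ys = [sign(v) for v in xs]
# When no upstream gradient is supplied, tf.gradients seeds it with ones,
# so the sketch does the same here.
grads = clipped_ste_grad(xs, [1.0] * len(xs))
print(sum(grads))  # number of points that fall inside [-1, 1]
```

Because the upstream gradient is all ones, the result is simply an indicator of which sample points lie in [-1, 1], which is why the plot of the gradient is a rectangular pulse.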


Testing the two-argument form

import numpy as np  # needed for np.linspace below
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
@tf.custom_gradient
def DoublySign(x1,x2):
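    def grad(dy1, dy2):
        '''
        dy1 and dy2 are the upstream gradients of the two outputs, i.e. the gradients arriving from the following layer during backpropagation
        '''
        cond = (x1 >= -1) & (x1 <= 1)
        zeros = tf.zeros_like(dy1)
        return tf.where(cond, dy1, zeros), dy2
    # grad computes the gradient by hand: for x1 in [-1, 1] the gradient of sign is treated as 1, so by the chain rule the result is dy1 * 1 = dy1
    # Everywhere else the gradient is dy1 * 0 = 0; the second output is the identity, so dy2 passes through unchanged
    # On the difference between tf.where and tf.cond: https://blog.csdn.net/xiadimichen14908/article/details/83592282
    return (tf.sign(x1), x2), grad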


x1 = tf.constant(np.linspace(-2,2,100))
x2 = tf.constant(np.linspace(-2,2,100))
y1,y2 = DoublySign(x1,x2)
grad1 = tf.gradients(y1,x1)
grad2 = tf.gradients(y2,x2)
init = tf.global_variables_initializer()
with tf.Session() as sess:
#     sess.run(init)
#     print(x1.eval())
#     print(x2.eval())
#     print(y1.eval())
#     print(y2.eval())
#     print(sess.run(grad1))
#     print(sess.run(grad2))
    plt.figure(figsize=(10,5))
    plt.subplot(2,2,1)
    ax = plt.gca()         # get the current axes object
    plt.xlabel('x')
    plt.ylabel('sign(x)')                      # set the axis labels
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')         # make the right and top spines colorless, which effectively hides them

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')        # use the bottom spine as the x axis and the left spine as the y axis

    ax.spines['bottom'].set_position(('data', 0))   # 'data' positioning: anchor the bottom spine (the x axis) at y = 0
    ax.spines['left'].set_position(('data', 0))     # likewise, anchor the left spine (the y axis) at x = 0
    plt.plot(x1.eval(),y1.eval())
    plt.title("sign")
    plt.subplot(2,2,2)
    plt.xlabel('x')
    plt.ylabel('d(sign(x))')                   # set the axis labels
    plt.plot(x1.eval(),sess.run(grad1)[0])
    plt.title("sign grad")
    
    # Bottom row: the second output y2 = x2 is the identity, not sign, so it is labeled accordingly
    plt.subplot(2,2,3)
    ax = plt.gca()         # get the current axes object
    plt.xlabel('x2')
    plt.ylabel('y2')                           # set the axis labels
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')         # make the right and top spines colorless, which effectively hides them

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')        # use the bottom spine as the x axis and the left spine as the y axis

    ax.spines['bottom'].set_position(('data', 0))   # 'data' positioning: anchor the bottom spine (the x axis) at y = 0
    ax.spines['left'].set_position(('data', 0))     # likewise, anchor the left spine (the y axis) at x = 0
    plt.plot(x2.eval(),y2.eval())
    plt.title("identity")
    plt.subplot(2,2,4)
    plt.xlabel('x2')
    plt.ylabel('d(y2)')                        # set the axis labels
    plt.plot(x2.eval(),sess.run(grad2)[0])
    plt.title("identity grad")
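
In the two-argument form, grad takes one upstream gradient per output and returns one gradient per input. The pure-Python sketch below mirrors that contract on four sample points. It is an illustration only: `doubly_sign_grad` is a hypothetical helper, and the sketch assumes that when only one output is differentiated, the upstream gradient of the other output arrives as zeros.

```python
def doubly_sign_grad(x1, dy1, dy2):
    """Take one upstream gradient per output, return one gradient per input."""
    # First input: pass dy1 through only where -1 <= x1 <= 1.
    dx1 = [g if -1 <= v <= 1 else 0.0 for v, g in zip(x1, dy1)]
    # Second input: y2 = x2 is the identity, so dy2 passes through unchanged.
    dx2 = list(dy2)
    return dx1, dx2


x1 = [-2.0, -0.5, 0.5, 2.0]
ones = [1.0] * len(x1)
zeros = [0.0] * len(x1)

# Differentiating y1 only: upstream is ones for y1 and (assumed) zeros for y2.
dx1, _ = doubly_sign_grad(x1, ones, zeros)
# Differentiating y2 only: upstream is (assumed) zeros for y1 and ones for y2.
_, dx2 = doubly_sign_grad(x1, zeros, ones)

print(dx1)
print(dx2)
```

The gradient with respect to x1 is the same rectangular pulse as in the single-argument case, while the gradient with respect to x2 is constant at 1, matching the two right-hand panels of the figure.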
