TF2 build-in Keras在eager及非eager模式下callback训练过程中梯度的方式

Class Activation Map / Gradient Attention Map

分类/分割任务中可能会需要对训练过程中某些层的计算梯度进行操作,对于Keras来说我们可以通过使用Callback()实现返回梯度的目的,具体的例子如下所示,分为非eager模式和eager模式两部分。

1. 非eager模式

tf.compat.v1.disable_eager_execution()  # 这句一定要加上!

def get_gradient_func(model):
    ## If using 'tf.compat.v1.disable_eager_execution()'
    # 博主使用的模型有两个输出所以此处定义了两个gradients及function
    grads_1 = K.gradients(model.outputs[0], model.inputs[0])
    grads_2 = K.gradients(model.outputs[1], model.inputs[0])
    inputs = model._feed_inputs + model._feed_targets + model._feed_sample_weights
    func_1 = K.function(inputs, grads_1)
    func_2 = K.function(inputs, grads_2)
    return func_lipid, func_calcium

class CustomCallback(Callback):
    def __init__(self,model,training_generator,save_grad_path):
        self.model = model
        self.training_generator = training_generator
        self.save_grad_path = save_grad_path

    def on_epoch_end(self, epoch, logs=None):
        if (epoch+1)%10==0:
            epoch_gradient_1 = []
            epoch_gradient_2 = []
            get_gradient_1, get_gradient_2 = get_gradient_func(self.model)
            for step,batch in enumerate(self.training_generator):
                batch = tuple(t for t in batch)
                train_img = batch[0]
                train_label = batch[1]
                grads_1 = get_gradient_1([train_img, train_label, np.ones(16)]) # batchSize=16
                grads_2 = get_gradient_2([train_img, train_label, np.ones(16)])
                # 存储每个epoch output对input的梯度,下一个epoch时epoch_gradient变量会清空
                epoch_gradient_1.append(grads_1[0][:,:,:,3])
                epoch_gradient_2.append(grads_1[0][:,:,:,3])
           
        else:
            pass

2. eager模式

class CustomCallback(Callback):
    def __init__(self,model,training_generator,save_grad_path):
        self.model = model
        self.training_generator = training_generator
        self.save_grad_path = save_grad_path

    def on_epoch_end(self, epoch, logs=None):
        if (epoch+1)%10==0:
            epoch_gradient_1 = []
            epoch_gradient_2 = []
            input_layer = self.model.get_layer("data") # 模型的输入层'data',也可以是其他名字,根据model各层的起名来定
            # 由于是计算output对input的梯度,所以定义一个临时的模型用来进行out,data这两个tensor的输出
            # 若想计算Output关于其他层的梯度,只需要将input_layer.output替换为其他层的output即可
            temp_model = Model([self.model.inputs],[self.model.output,input_layer.output])
            for step,batch in enumerate(self.training_generator):
                batch = tuple(t for t in batch)
                train_img = batch[0]
                train_label = batch[1]
                # 默认的non-persisitent模式下,with tf.GradientTape() as gtape:一次只能使用gtape.gradient一次,连续使用会报错
                with tf.GradientTape() as gtape:
                    out, data = temp_model(train_img)
                    # 由于gtape只能跟踪trainable variants,而model的input是一个non-trainable的变量,所以要使用gtape.watch()进行追踪
                    gtape.watch(data) 
                grads_1 = gtape.gradient(out[0], data)
                with tf.GradientTape() as gtape:
                    out, data = temp_model(train_img)
                    gtape.watch(data)
                grads_2 = gtape.gradient(out[1], data)
                epoch_gradient_1.append(grads_1[:,:,:,3])
                epoch_gradient_2.append(grads_2[:,:,:,3])

        else:
            pass

个人推荐使用eager模式。

References:

https://stackoverflow.com/questions/58322147/how-to-generate-cnn-heatmaps-using-built-in-keras-in-tf2-0-tf-keras

https://stackoverflow.com/questions/61568665/tf2-compute-gradients-in-keras-callback-in-non-eager-mode
https://discuss.pytorch.org/t/generating-the-class-activation-maps/42887
https://www.tensorflow.org/api_docs/python/tf/GradientTape#gradient

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值