Tensorflow自定义层踩的大坑

Tensorflow自定义层踩的大坑

tensorflow自定义层踩的大坑(调了两天的BUG)


前言

由于最近在做注意力相关的内容,所以用到了自定义层。在tensorflow的框架下进行l模型训练没问题,但是加载模型出现了各种各样的问题,给大家讲一下,希望大家别和我一样调两天的BUG.


一、自定义层的__init__中的参数问题

class BMW(layers.Layer):
    def __init__(self, channels, c2=None, factor=32,name=None,**kwargs):
        super(EMA, self).__init__(name=name,**kwargs)
        self.channels=channels
        self.groups = factor
        assert channels // self.groups > 0

以上式为例,我init中有两个需要一个需要传递的参数,就是channel,这时侯,必须要在init里进行 self.channels=channels的这种定义。否则python会出现一些报错,说是缺失channel。这个定义一定要记住!!!!!!!

二、在CLASS类中的最后一个函数定义get_cofig函数

上方的代码将channel进行了self的定义,需要在最后定义个get_config函数,而且要加上在config.update中加上"channels": self.channels。再说一个非常重要的事情,比如你在init里还定义了一些卷积层和池化层等等,这个更新的函数里,一定不要更新这些东西,更新了这东西,相信我,会后悔的。这东西折磨我两天啊!!!!!

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            "channels":  self.channels
        })
        return config
  • 6
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值