tensorflow2.0-keras: fit 源码初探1【2020-1-22 09:39:47】

最近由于工作需要,keras提供的fit函数不能满足需求,而自己去写一个自定义的fit函数后进行测试时,砍掉加入的其他组件,只对自己的fit函数功能进行测试,却始终不能得到与keras原生fit函数一样的结果,因此不得已看源码学习keras是如何实现的。

以下内容将十分琐碎,有空进行整理,仅记录对我有用的东西,开始吧。

 

输入数据

  1. 输入数据会被转换成dataset对象

 

中间结果:

  1. 使用MetricsAggregator对象记录中间每个batch的结果,主要有两项:loss、metric,其中loss值是累计的源码如下:

  2. metric也是累计的,只是与loss不同的是,loss是被返回后由aggregator来累计,而metric则是由model中的metric对象来完成。
    原因:
            每次epoch开始时需要进行reset_metrics()操作。
            具体:在追踪过程中,进入metric对象中发现:

    而这两个值一除就是当前输出的accuracy值:

     

  3. 单次batch执行:
    batch_outs = execution_function(iterator)

    这就是执行单个batch训练的函数,很简单,得到的batch_outs内容如下: 

     这里一定要注意:loss没有累计,而metric均是累计的结果,因此loss后续自己进行处理。

  4. epoch结束,终结结果汇总:

     

  5. fit中各项metric(loss也是其中一种)的输出很有意思,虽然最后输出的都是累计值,但是其累计的地方有差异。
    accuracy等非loss值的累计:是传入的指标函数中进行,即在compile函数中传入的metric参数指定的内容,keras的fit函数会自动在这些指标函数上加一个meanwrapper包装,

    这个wrapper实现了指标值累计的功能。
    loss值的累计:loss值的累计在两个地方,一个是上文提到的aggregator变量负责,而另一个则是进度条对象中,即这个类:(ProgbarLogger类,这是一个Callback的子类),在这个类中内置value变量负责记录历史结果,输出时进行求平均,

    ,另外Progbar这个进度条类被我改吧改吧后收藏了,非常棒的一个工具箱,比tqdm好看多了。
    为什么两个地方进行累计loss:原因是aggregator虽然累计了历史结果,但只是在整个epoch过程完成后进行汇总,得到整个epoch的整体结果,这个是数值上的,即aggregator的目标是缓存历史结果,但只在整个过程结束后反馈给用户。而progbar则是每运行一次则更新一次结果,时时刻刻在进度条上显示出来,偏重可视化的效果。
    另外aggregator和progbar是两个不同的工具,并不一定非得在一起用,只是在fit中是组合在一起了,虽然功能有重叠,但不妨碍各自独立的功能,即可以各用个的,不妨碍对方。
  6.  
  7.  

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值