最近由于工作需要,keras提供的fit函数不能满足需求,而自己去写一个自定义的fit函数后进行测试时,砍掉加入的其他组件,只对自己的fit函数功能进行测试,却始终不能得到与keras原生fit函数一样的结果,因此不得已看源码学习keras是如何实现的。
以下内容将十分琐碎,有空进行整理,仅记录对我有用的东西,开始吧。
输入数据
- 输入数据会被转换成dataset对象
中间结果:
- 使用MetricsAggregator对象记录中间每个batch的结果,主要有两项:loss、metric,其中loss值是累计的源码如下:
-
metric也是累计的,只是与loss不同的是,loss是被返回后由aggregator来累计,而metric则是由model中的metric对象来完成。
原因:
每次epoch开始时需要进行reset_metrics()操作。
具体:在追踪过程中,进入metric对象中发现:而这两个值一除就是当前输出的accuracy值:
- 单次batch执行:
batch_outs = execution_function(iterator)
这就是执行单个batch训练的函数,很简单,得到的batch_outs内容如下:
这里一定要注意:loss没有累计,而metric均是累计的结果,因此loss后续自己进行处理。
- epoch结束,终结结果汇总:
- fit中各项metric(loss也是其中一种)的输出很有意思,虽然最后输出的都是累计值,但是其累计的地方有差异。
accuracy等非loss值的累计:是传入的指标函数中进行,即在compile函数中传入的metric参数指定的内容,keras的fit函数会自动在这些指标函数上加一个meanwrapper包装,
这个wrapper实现了指标值累计的功能。
loss值的累计:loss值的累计在两个地方,一个是上文提到的aggregator变量负责,而另一个则是进度条对象中,即这个类:(ProgbarLogger类,这是一个Callback的子类),在这个类中内置value变量负责记录历史结果,输出时进行求平均,
,另外Progbar这个进度条类被我改吧改吧后收藏了,非常棒的一个工具箱,比tqdm好看多了。
为什么两个地方进行累计loss:原因是aggregator虽然累计了历史结果,但只是在整个epoch过程完成后进行汇总,得到整个epoch的整体结果,这个是数值上的,即aggregator的目标是缓存历史结果,但只在整个过程结束后反馈给用户。而progbar则是每运行一次则更新一次结果,时时刻刻在进度条上显示出来,偏重可视化的效果。
另外aggregator和progbar是两个不同的工具,并不一定非得在一起用,只是在fit中是组合在一起了,虽然功能有重叠,但不妨碍各自独立的功能,即可以各用个的,不妨碍对方。