问题描述:
Callback 中 run_context.original_args() 方法能够取得哪些参数,Callback docstring 中丝毫没有提及,还只能教程“自定义Callback”中只能找到一个粗略的描述,更详细的信息(如 train_network 是指 with_loss_network 还是 without_loss_network)则没有。
1. train_network 就是模型自身吧?我可以直接调用它来做checkpoint吗?
2. 我在 callback 中进行 model.eval 需要设置 `model.set_train(false)` 吗?如果需要的话,那我后续训练是不是就无法进行了,因为回去train的时候mode已经被切换成 `predict` 了
PS:省略号表达的意思太含混了
解答:
1. `train_network` 是带有优化器的训练网络。可以调用`save_checkpoint`把`train_network`保存成ckpt文件。
- 想问的是`net.set_train()`吧,`callback`里面的`train_network`就是训练的网络,如果想进行推理,可以train_network.set_train(False),再train_network.set_train(True)。
非常感谢您提出的问题,我们将在文档中补充如下信息供参考。
1. 在RunContext类中说明使用场景和推荐用法,以及与Callback的关系,并链到文档中的已有用例。
2. 在Callback类中列举框架中在各场景下已支持的所有属性,同时说明可支持用户自定义。