Auto-Keras API详解(3)——Supervised类
(一)前 言
Supervised类是Auto-Keras中用于所有监督学习任务的基类,这一节我们将详细介绍它的主要方法和各项参数的意义。
(二)方法详解
(1)fit方法
fit方法用于寻找最优的网络结构并且加以训练,这个函数会基于给定的数据集,为该数据集找到最佳的神经网络结构,数据集的格式为Numpy数据型,训练数据需要通过x_train,y_train传递。
- 参数列表:
- x:
一个Numpy数组的实例,包含了训练数据或者是训练数据与验证数据结合的数据 - y:
一个Numpy数组的实例,包含了训练数据的标签或者是训练标签与验证标签结合的数据 - x_test:
一个Numpy数组,包含了测试数据 - y_test :
一个Numpy数组,包含了测试数据的标签 - time_limit:
搜索网络的时间限制
- x:
(2)final_fit方法
final_fit方法用于找到最优网络后做最后的训练。
- 参数列表:
- x:
一个Numpy数组的实例,包含了训练数据或者是训练数据与验证数据结合的数据 - y:
一个Numpy数组的实例,包含了训练数据的标签或者是训练标签与验证标签结合的数据 - x_test:
一个Numpy数组,包含了测试数据 - y_test :
一个Numpy数组,包含了测试数据的标签 - trainer_args:
一个包含了ModelTrainer结构参数的字典 - retrain:
一个布尔值,用来决定是否重新初始化模型的权重参数
- x:
(3)predict方法
predict方法用来测试数据的预测值。
- 参数列表:
- x_test:
一个Numpy数组,包含了测试数据
- x_test:
(4)evaluate方法
evaluate方法用来在预测值和实际值之间评估模型的精度。
- 参数列表:
- x_test:
一个Numpy数组,包含了测试数据 - y_test :
一个Numpy数组,包含了测试数据的标签
- x_test:
(三)总 结
在这一节中,我们介绍了Supervised类的相关方法,有任何的问题请在评论区留言,我会尽快回复,谢谢支持!