surprise库有一组内建 数据集,但当然可以使用自定义数据集。加载评分数据集可以从文件(例如csv文件)或pandas数据框中完成。无论哪种方式,都需要定义一个Reader
对象来解析文件或数据框。\
要从文件加载数据集(例如csv文件),需要 load_from_file()
方法:
from surprise import BaselineOnly
from surprise import Dataset
from surprise import Reader
from surprise.model_selection import cross_validate
file_path = os.path.expanduser('~/.surprise_data/ml-100k/ml-100k/u.data')#数据集文件所在目录
reader = Reader(line_format='user item rating timestamp', sep='\t')
data = Dataset.load_from_file(file_path, reader=reader)
cross_validate(BaselineOnly(), data, verbose=True) #现在可以使用这个数据集,例如调用cross_validate
从pandas数据框加载数据集,需要使用 load_from_df()
方法。这里不多赘述,可自己查看说明。
使用交叉验证迭代器
对于交叉验证,可以用 cross_validate()
完成所有的工作。但是为了更好地控制,可以实例化交叉验证迭代器,并使用且迭代器的 split()
方法和算法的 test()
方法,对每一折进行预测。
下面是一个栗子,我们使用了一个经典的K-fold交叉验证程序,其中包含数据被分为3份(3折交叉验证):
from surprise import SVD
from surprise import Dataset
from surprise import accuracy
from surprise.model_selection import KFold
data = Dataset.load_builtin('ml-100k') #加载数据集
# define a cross-validation iterator
kf = KFold(n_splits=3) #定义交叉验证迭代器
algo = SVD()
for trainset, testset in kf.split(data):
# 训练并测试算法
algo.fit(trainset)
predictions = algo.test(testset)
# 计算并打印RMSE
accuracy.rmse(predictions, verbose=True)
结果:
RMSE: 0.9374 RMSE: 0.9476 RMSE: 0.9478
也可以使用其他交叉验证迭代器,例如LeaveOneOut或ShuffleSplit。在这里查看所有可用的迭代器。Surprise的交叉验证工具的设计灵感来源于优秀的scikit-learn API。
交叉验证的一个特例是folds已经由某些文件预定义,这里同样查看说明。
使用GridSearchCV调整算法参数
该cross_validate()
函数针对给定的一组交叉验证参数报告过程的准确性度量(如RMSE、MAE这些)。如果你想知道哪个参数组合能够产生最好的结果,那么这个 GridSearchCV
类就可以解决问题。给定一个dict
参数,这个类会尝试所有的参数组合,并报告任何准确性度量(对不同分割进行平均的)的最佳参数。它受到scikit-learn的GridSearchCV的启发。
接下来这个例子我们尝试了SVD算法的参数 n_epochs
, lr_all
和 reg_all
的不同值。
from surprise import SVD
from surprise import Dataset
from surprise.model_selection import GridSearchCV
data = Dataset.load_builtin('ml-100k')
param_grid = {'n_epochs': [5, 10], 'lr_all': [0.002, 0.005],
'reg_all': [0.4, 0.6]}
gs = GridSearchCV(SVD, param_grid, measures=['rmse', 'mae'], cv=3)
gs.fit(data)
# best RMSE score
print(gs.best_score['rmse'])
# combination of parameters that gave the best RMSE score
print(gs.best_params['rmse'])
结果:
0.961300130118 {'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.4}
我们在这里评估3倍交叉验证过程的平均RMSE和MAE,但可以使用任何交叉验证迭代器。
一旦fit()
被调用, best_estimator
这个属性给了我们一个算法实例最优的一组参数,可以根据我们的喜好使用它:
# 可以使用产生最优RMSE的算法
algo = gs.best_estimator['rmse']
algo.fit(data.build_full_trainset())
注意:字典参数,例如bsl_options
与sim_options
需要特殊对待。请参阅以下使用示例:
param_grid = {'k': [10, 20],
'sim_options': {'name': ['msd', 'cosine'],
'min_support': [1, 5],
'user_based': [False]}
}
当然,两者可以结合使用,例如 KNNBaseline
:
param_grid = {'bsl_options': {'method': ['als', 'sgd'],
'reg': [1, 2]},
'k': [2, 3],
'sim_options': {'name': ['msd', 'cosine'],
'min_support': [1, 5],
'user_based': [False]}
}
为了进一步分析,cv_results
属性具有所有需要的信息,并且可以在pandas数据框中导入:
results_df = pd.DataFrame.from_dict(gs.cv_results)
在我们的例子中,该cv_results
属性看起来像这样(float格式):
'split0_test_rmse' : [ 1.0 , 1.0 , 0.97 , 0.98 , 0.98 , 0.99 , 0.96 , 0.97 ] 'split1_test_rmse' : [ 1.0 , 1.0 , 0.97 , 0.98 , 0.98 , 0.99 , 0.96 , 0.97 ] 'split2_test_rmse' : [ 1.0 , 1.0 , 0.97 , 0.98 , 0.98 , 0.99 , 0.96 , 0.97 ] 'mean_test_rmse' : [ 1.0 , 1.0 , 0.97 , 0.98 , 0.98 , 0.99 , 0.96 , 0.97 ] 'std_test_rmse' : [ 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ] 'rank_test_rmse' : [ 7 8 3 5 4 6 1 2 ] 'split0_test_mae' : [ 0.81 , 0.82 , 0.78 , 0.79 , 0.79 , 0.8 , 0.77 , 0.79 ] 'split1_test_mae' : [ 0.8 , 0.81 , 0.78 , 0.79 , 0.78 , 0.79 , 0.77 , 0.78 ] 'split2_test_mae' : [ 0.81 , 0.81 , 0.78 , 0.79 , 0.78 , 0.8 , 0.77 , 0.78 ] 'mean_test_mae' : [ 0.81 , 0.81 , 0.78 , 0.79 , 0.79 , 0.8 , 0.77 , 0.78 ] 'std_test_mae' : [ 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0 ] 'rank_test_mae' : [ 7 8 2 5 4 6 1 3 ] 'mean_fit_time' : [ 1.53 , 1.52 , 1.53 , 1.53 , 3.04 , 3.05 , 3.06 , 3.02 ] 'std_fit_time' : [ 0.03 , 0.04 , 0.0 , 0.01 , 0.04 , 0.01 , 0.06 , 0.01 ] 'mean_test_time' : [ 0.46 , 0.45 , 0.44 , 0.44 , 0.47 , 0.49 , 0.46 , 0.34 ] 'std_test_time' : [ 0.0 , 0.01 , 0.01 , 0.0 , 0.03 , 0.06 , 0.01 , 0.08 ] 'PARAMS' : [{ 'n_epochs' : 5 , 'lr_all' : 0.002 , 'reg_all' : 0.4 }, { 'n_epochs' : 5 , 'lr_all' : 0.002 , 'reg_all' : 0.6 }, {'n_epochs' : 5 , 'lr_all' : 0.005 , 'reg_all' : 0.4 }, { 'n_epochs' : 5 , 'lr_all' : 0.005 , 'reg_all' : 0.6 }, { 'n_epochs' : 10 , 'lr_all' : 0.002 , 'reg_all' : 0.4 }, { 'n_epochs' : 10 , 'lr_all' : 0.002 , 'reg_all' : 0。6 }, {'n_epochs' : 10 , 'lr_all' : 0.005 , 'reg_all' : 0.4 }, { 'n_epochs' : 10 , 'lr_all' : 0.4 ,0.6 ,0.4 ,0.6 ]0.005, 'reg_all': 0.6}] 'param_n_epochs': [5, 5, 5, 5, 10, 10, 10, 10] 'param_lr_all': [0.0, 0.0, 0.01, 0.01, 0.0, 0.0, 0.01, 0.01] 'param_reg_all': [0.4, 0.6, 0.4, 0.6,