在所有可取的<C, g>值对下,grid.py 都训练出一个分类器,并计算该分类器的准确率,将准确率最高的分类器所对应的<C, g>作为最优的参数。这是通过交叉验证来实现的:将训练集分为 N 折,其中的 N-1 折用来训练分类器,剩下的 1 折用来计算分类准确率,从而计算 N 次实验的平均准确率作为给定<C, g>值对下的分类器的准确率。
----------------------------------------------------------------------------------------------------------------------------------------
全局设置:
# svmtrain and gnuplot executable
is_win32 = (sys.platform == 'win32')
if not is_win32:
svmtrain_exe = "../svm-train"
gnuplot_exe = "/usr/bin/gnuplot"
else:
# example for windows
svmtrain_exe = r"..\windows\svm-train.exe"
# windows平台需要修改程序 gnuplot 的路径
gnuplot_exe = r"c:\Program Files (x86)\gnuplot\bin\pgnuplot.exe"
# global parameters and their default values
# 全局变量,同时设置默认的参数
# 5-折交叉验证
fold = 5
# 参数 c和g 的范围 begin~end 以及步长 step
c_begin, c_end, c_step = -5, 15, 2
g_begin, g_end, g_step = 3, -15, -2
# dataset_pathname:训练数据路径,dataset_title:训练数据文件名,pass_through_string:用户给定的 svm 分类器的参数
global dataset_pathname, dataset_title, pass_through_string
# out_filename:文件名,该文件记录交叉验证过程中计算得到的每组数据(c,g,rate等),png_filename:图片的文件名,将 out_filename 中的数据画成二维图形
global out_filename, png_filename
# experimental
# 设置 telnet 和 ssh 的节点名称--- grid.py 的网格搜索可以并行化
telnet_workers = []
ssh_workers = []
# 本地节点数量
nr_local_worker = 1
----------------------------------------------------------------------------------------------------------------------------------------
处理命令行传入的参数:
# process command line options, set global parameters
def process_options(argv=sys.argv):
# 将这些变量设为全局的,免去函数间的传参
global fold
global c_begin, c_end, c_step
global g_begin, g_end, g_step
global dataset_pathname, dataset_title, pass_through_string
# gnuplot 是画图的句柄,要随时将得到的每组数据(c,g,rate等)画在图上
global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename
# 命令提示
usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold]
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""
# 命令使用错误提示
if len(argv) < 2:
print(usage)
sys.exit(1)
# 变量赋值
dataset_pathname = argv[-1]
dataset_title = os.path.split(dataset_pathname)[1]