libsvm 学习笔记(四)--- grid.py 关键代码详解

在所有可取的<C, g>值对下,grid.py 都训练出一个分类器,并计算该分类器的准确率,将准确率最高的分类器所对应的<C, g>作为最优的参数。这是通过交叉验证来实现的:将训练集分为 N 折,其中的 N-1 折用来训练分类器,剩下的 1 折用来计算分类准确率,从而计算 N 次实验的平均准确率作为给定<C, g>值对下的分类器的准确率。


----------------------------------------------------------------------------------------------------------------------------------------

全局设置:


# svmtrain and gnuplot executable


is_win32 = (sys.platform == 'win32')
if not is_win32:
       svmtrain_exe = "../svm-train"
       gnuplot_exe = "/usr/bin/gnuplot"
else:
       # example for windows
       svmtrain_exe = r"..\windows\svm-train.exe"

       # windows平台需要修改程序 gnuplot 的路径
       gnuplot_exe = r"c:\Program Files (x86)\gnuplot\bin\pgnuplot.exe"


# global parameters and their default values

# 全局变量,同时设置默认的参数

# 5-折交叉验证
fold = 5

# 参数 c和g 的范围 begin~end 以及步长 step
c_begin, c_end, c_step = -5,  15, 2
g_begin, g_end, g_step =  3, -15, -2

dataset_pathname:训练数据路径,dataset_title:训练数据文件名,pass_through_string:用户给定的 svm 分类器的参数
global dataset_pathname, dataset_title, pass_through_string
out_filename:文件名,该文件记录交叉验证过程中计算得到的每组数据(c,g,rate等),png_filename:图片的文件名,将 out_filename 中的数据画成二维图形

global out_filename, png_filename


# experimental

# 设置 telnet 和 ssh 的节点名称--- grid.py 的网格搜索可以并行化
telnet_workers = []
ssh_workers = []

# 本地节点数量
nr_local_worker = 1


----------------------------------------------------------------------------------------------------------------------------------------

处理命令行传入的参数:



# process command line options, set global parameters
def process_options(argv=sys.argv):

    # 将这些变量设为全局的,免去函数间的传参
    global fold
    global c_begin, c_end, c_step
    global g_begin, g_end, g_step
    global dataset_pathname, dataset_title, pass_through_string

    gnuplot 是画图的句柄,要随时将得到的每组数据(c,g,rate等)画在图上
    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename
    

    # 命令提示
    usage = """\
Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] 
[-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname]
[additional parameters for svm-train] dataset"""

    # 命令使用错误提示
    if len(argv) < 2:
        print(usage)
        sys.exit(1)

    # 变量赋值
    dataset_pathname = argv[-1]
    dataset_title = os.path.split(dataset_pathname)[1]

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值