预备
- NTU(TW) Chih-Chung Chang and Chih-Jen Lin LIBSVM
- LIBSVM Data: Classification, Regression, and Multi-label
正文
a 编译libsvm
uname@hname:~/libsvm$ make
g++ -Wall -Wconversion -O3 -fPIC -c svm.cpp
g++ -Wall -Wconversion -O3 -fPIC svm-train.c svm.o -o svm-train -lm
g++ -Wall -Wconversion -O3 -fPIC svm-predict.c svm.o -o svm-predict -lm
g++ -Wall -Wconversion -O3 -fPIC svm-scale.c -o svm-scale
g++ -Wall -Wconversion -O3 -fPIC -c svm.cpp
g++ -Wall -Wconversion -O3 -fPIC svm-train.c svm.o -o svm-train -lm
g++ -Wall -Wconversion -O3 -fPIC svm-predict.c svm.o -o svm-predict -lm
g++ -Wall -Wconversion -O3 -fPIC svm-scale.c -o svm-scale
b svmguide1 (类别平衡数据)
代码wrap.py, 详见附录.
$ python3 wrap.py --kernel rbf --C 1
Accuracy = 66.925% (2677/4000) (classification)
$ python3 wrap.py --kernel rbf --C 1 --scale
Accuracy = 96.15% (3846/4000) (classification)
$ python3 wrap.py --kernel linear --C 1
Accuracy = 95.675% (3827/4000) (classification)
$ python3 wrap.py --kernel linear --C 1 --scale
Accuracy = 95.675% (3827/4000) (classification)
$ python3 wrap.py --kernel rbf --C 1000
Accuracy = 70.475% (2819/4000) (classification)
$ python3 wrap.py --kernel rbf --C 1000 --scale
Accuracy = 96.725% (3869/4000) (classification)
$ python3 easy.py ../../data/svmguide1.tr ../../data/svmguide1.te
Best c=2.0, g=2.0 CV rate=96.9893
Accuracy = 96.875% (3875/4000) (classification)
结论
- rbf核svm对数据规范化敏感, 线性核svm对数据规范化不敏感.
- 可以通过调整超参数(正则化参数)提高(rbf核)svm的性能.
- 调参(搜索最优超参数)通常采用网格搜索(grid search), 代价很大: 需要训练和测试分类器的次数随待调参数的个数呈指数级增长.
注意: easy.py 需要修改一行代码 (第 52 行), 让它通过 python3 来调用 grid.py:
52c52
< cmd = '{0} -svmtrain "{1}" -gnuplot "{2}" "{3}"' ...
---
> cmd = 'python3 {0} -svmtrain "{1}" -gnuplot "{2}" "{3}"' ...
否则需要给 grid.py 加上可执行权限 (只需 `+x`, 不必 `+777`): `$ chmod +x /path/to/grid.py`.
c shuttle 类别不平衡数据
代码peek.py, 详见附录.
dataset | label count | note |
---|---|---|
svmguide1 | {1: 4000, 0: 3089} | balanced |
shuttle | {1: 35033, 4: 6906, 5: 2525, 3: 133, 2: 37, 7: 11, 6: 6} | imbalanced |
代码wrap.py, 详见附录.
$ python3 wrap.py --dataset shuttle
Accuracy = 98.0017% (1128/1151) (classification)
$ python3 wrap.py --dataset shuttle --weight_class
Accuracy = 85.0565% (979/1151) (classification)
根据各类别的样本数 {1: 35033, 4: 6906, 5: 2525, 3: 133, 2: 37, 7: 11, 6: 6}, 以 "最大类样本数 / 该类样本数" 作为权重, 输入一组参数:
-w4 5.0728352157544165 (== 35033 / 6906)
-w5 13.874455445544555 (== 35033 / 2525)
-w3 263.406015037594 (== 35033 / 133)
-w2 946.8378378378378 (== 35033 / 37)
-w7 3184.818181818182 (== 35033 / 11)
-w6 5838.833333333333 (== 35033 / 6)
这一组权重十分夸张, 总体准确率大幅下降了. 但是仔细对比预测输出和真实标签就会发现: 模型在样本较多的类上错误增加, 而在样本较少的类上错误减少.
代码tmp.py, 详见附录.
argument | error |
---|---|
with weight_class | {1: 167, 4: 5} |
without weight_class | {1: 9, 4: 13, 3: 1} |
启发式地输入另一组 (温和得多的) 参数:
-w4 1.025
-w3 1.005
好吧, 只多正确分类了一个样本🙃. 这个类别不平衡数据集比较特殊: 即使不给样本较少的类别赋予较大的权重, 分类器也已经能够达到很高的性能, 因此体现不出 -wi 参数的作用.
argument | error |
---|---|
with weight_class | {1: 9, 4: 12, 3: 1} |
without weight_class | {1: 9, 4: 13, 3: 1} |
附录
peek.py
from typing import List
from collections import Counter
def collect_lbl(path: str, lbls: List[int]) -> List[int]:
    """Append the class label of every instance in a LIBSVM-format file to ``lbls``.

    Only the first whitespace-separated token of each line is read and parsed
    with ``int``; the ``index:value`` feature pairs after it are ignored.
    Blank lines (such as a trailing empty line) are skipped instead of raising
    ``IndexError``.

    Args:
        path: Path of the LIBSVM-format data file to read.
        lbls: List that the parsed labels are appended to, in file order.

    Returns:
        ``lbls`` itself (the same list object), for convenience.
    """
    with open(path, mode='rt') as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            fields = line.split()
            if not fields:
                continue
            lbls.append(int(fields[0]))
    return lbls
# Gather the labels of the train and test splits of both datasets, then print
# how many instances each class has, to judge whether the classes are balanced.
svmguide1_lbls = []
shuttle_lbls = []
for bucket, stem in ((svmguide1_lbls, 'svmguide1'), (shuttle_lbls, 'shuttle')):
    for split in ('tr', 'te'):
        collect_lbl(f'./data/{stem}.{split}', bucket)
svmguide1_cnt = Counter(svmguide1_lbls)
shuttle_cnt = Counter(shuttle_lbls)
for title, tally in (('svmguide1', svmguide1_cnt), ('shuttle', shuttle_cnt)):
    print(title, tally)
wrap.py
import os
import sys
import subprocess
from pathlib import Path
from argparse import ArgumentParser
# Layout expected by this script, relative to the current working directory:
#   ./libsvm/ holds the svm-scale / svm-train / svm-predict executables;
#   ./data/   holds the datasets plus the model (.mdl), prediction (.out) and
#             scaling-parameter (.pr) files written by the LIBSVM tools.
workspace_dir = Path.cwd()
libsvm_dir = workspace_dir.joinpath('libsvm')
svm_pr, svm_tr, svm_te = (
    libsvm_dir.joinpath(exe) for exe in ('svm-scale', 'svm-train', 'svm-predict')
)
data_dir = workspace_dir.joinpath('data')
shuttle_tr, shuttle_te, shuttle_mdl, shuttle_out = (
    data_dir.joinpath(f'shuttle.{ext}') for ext in ('tr', 'te', 'mdl', 'out')
)
svmguide1_tr, svmguide1_te, svmguide1_mdl, svmguide1_out, svmguide1_pr = (
    data_dir.joinpath(f'svmguide1.{ext}') for ext in ('tr', 'te', 'mdl', 'out', 'pr')
)
# Scaled copies of svmguide1 are written under /tmp rather than ./data.
svmguide1_tr_tmp, svmguide1_te_tmp = (
    Path('/tmp', f'svmguide1.{split}.tmp') for split in ('tr', 'te')
)
def check_ready_file():
    """Exit the script unless the LIBSVM executables and dataset files exist.

    Every directory and input file this script relies on is checked. If any
    is missing, the expected layout and the missing paths are printed and the
    process exits with status -1.
    """
    required = [
        workspace_dir, libsvm_dir, svm_pr, svm_tr, svm_te,
        data_dir, shuttle_tr, shuttle_te, svmguide1_tr, svmguide1_te,
    ]
    missing = [path for path in required if not path.exists()]
    if not missing:
        return
    # Names are taken from the checked paths so the message matches the real
    # executables (svm-scale, svm-train, svm-predict).
    print(
        f'[ {svm_pr.name} {svm_tr.name} {svm_te.name} ]'
        f' should be in {libsvm_dir}/\n'
        f'[ {svmguide1_tr.name} {svmguide1_te.name} {shuttle_tr.name} {shuttle_te.name} ]'
        f' should be in {data_dir}/'
    )
    print('missing:', ' '.join(str(path) for path in missing))
    # sys.exit instead of the site-provided exit(), which is absent under `python -S`.
    sys.exit(-1)
def run_svmguide1(args):
    """Train and evaluate an SVM on svmguide1 with the LIBSVM command-line tools.

    When ``args.scale`` is set, svm-scale scales the training set and saves its
    scaling parameters to ``svmguide1_pr`` (``-s``). The test set is then scaled
    with those same parameters (``-r``), so both splits share one scaling. The
    scaled copies are written to ``svmguide1_tr_tmp`` / ``svmguide1_te_tmp``.

    Args:
        args: Parsed CLI namespace. Reads ``scale``, ``kernel`` (a key into
            ``kernel_to_t``) and ``C`` (passed as svm-train ``-c``).

    Raises:
        subprocess.CalledProcessError: if any LIBSVM tool exits non-zero. This
            stops the run instead of continuing on stale or missing files.
    """
    # Commands are argument lists (no shell), so paths containing spaces or
    # shell metacharacters are passed through intact. Prints are flushed so
    # they appear before the child's output when stdout is redirected.
    if args.scale:
        print('>>> scale', svmguide1_tr.name, flush=True)
        with open(svmguide1_tr_tmp, 'w') as scaled:
            subprocess.run(
                [str(svm_pr), '-s', str(svmguide1_pr), str(svmguide1_tr)],
                stdout=scaled,
                check=True,
            )
        train_file = svmguide1_tr_tmp
    else:
        train_file = svmguide1_tr

    print('>>> train', train_file.name, flush=True)
    subprocess.run(
        [
            str(svm_tr),
            '-t', str(kernel_to_t[args.kernel]),
            '-c', str(args.C),
            str(train_file), str(svmguide1_mdl),
        ],
        check=True,
    )

    if args.scale:
        print('>>> scale', svmguide1_te.name, flush=True)
        with open(svmguide1_te_tmp, 'w') as scaled:
            # -r reuses the training-set scaling parameters for the test set.
            subprocess.run(
                [str(svm_pr), '-r', str(svmguide1_pr), str(svmguide1_te)],
                stdout=scaled,
                check=True,
            )
        test_file = svmguide1_te_tmp
    else:
        test_file = svmguide1_te

    print('>>> test', test_file.name, flush=True)
    subprocess.run(
        [str(svm_te), str(test_file), str(svmguide1_mdl), str(svmguide1_out)],
        check=True,
    )
    return
def run_shuttle(args):
    """Train and evaluate an SVM on the class-imbalanced shuttle dataset.

    Args:
        args: Parsed CLI namespace. Reads ``kernel`` (a key into
            ``kernel_to_t``), ``C`` (passed as svm-train ``-c``) and
            ``weight_class`` (when true, per-class ``-w`` weights from
            ``class_weights`` are passed to svm-train as well).

    Raises:
        subprocess.CalledProcessError: if svm-train or svm-predict exits
            non-zero, so prediction is never run against a stale model.
    """
    # Per-class weights passed to svm-train as ``-w<label> <weight>``; LIBSVM
    # sets the C of that class to weight * C. These mild values were chosen
    # heuristically. The inverse-frequency alternative (largest class count /
    # class count) was also tried and lowered overall accuracy:
    #   4: 5.0728352157544165, 5: 13.874455445544555, 3: 263.406015037594,
    #   2: 946.8378378378378, 7: 3184.818181818182, 6: 5838.833333333333
    class_weights = {4: 1.025, 3: 1.005}

    train_cmd = [
        str(svm_tr),
        '-t', str(kernel_to_t[args.kernel]),
        '-c', str(args.C),
    ]
    if args.weight_class:
        for label, weight in class_weights.items():
            train_cmd += [f'-w{label}', str(weight)]
    train_cmd += [str(shuttle_tr), str(shuttle_mdl)]

    # Argument lists (no shell) keep paths with spaces intact; flushed prints
    # stay ordered before the child's output when stdout is redirected.
    print('>>> train', shuttle_tr.name, flush=True)
    subprocess.run(train_cmd, check=True)

    print('>>> test', shuttle_te.name, flush=True)
    subprocess.run(
        [str(svm_te), str(shuttle_te), str(shuttle_mdl), str(shuttle_out)],
        check=True,
    )
    return
# svm-train ``-t`` kernel codes: 0 = linear, 2 = radial basis function.
# Kept at module level because run_svmguide1 / run_shuttle read it.
kernel_to_t = {'rbf': 2, 'linear': 0}

if __name__ == '__main__':
    parser = ArgumentParser(description='Run LIBSVM on svmguide1 or shuttle.')
    parser.add_argument('--dataset', type=str,
                        choices=['svmguide1', 'shuttle'], default='svmguide1')
    parser.add_argument('--kernel', type=str,
                        choices=['rbf', 'linear'], default='rbf')
    # A float default keeps args.C the same type whether or not --C is given.
    parser.add_argument('--C', type=float, default=1.0,
                        help='svm-train cost parameter (-c)')
    parser.add_argument('--scale', action='store_true', default=False,
                        help='scale svmguide1 features with svm-scale first')
    parser.add_argument('--weight_class', action='store_true', default=False,
                        help='pass per-class -w weights when training on shuttle')
    args = parser.parse_args()
    print('>>> workspace', workspace_dir)
    print('>>> args', args)
    # Fail early with a clear message if binaries or datasets are missing.
    check_ready_file()
    # Explicit dispatch table instead of exec() on a formatted string.
    runners = {'svmguide1': run_svmguide1, 'shuttle': run_shuttle}
    runners[args.dataset](args)
tmp.py
from typing import List
from collections import Counter
def collect_lbl(path: str, lbls: List[int]) -> List[int]:
    """Append the label at the start of every line of ``path`` to ``lbls``.

    Works for a LIBSVM-format data file (label followed by ``index:value``
    pairs) and for a prediction file with one label per line. Only the first
    whitespace-separated token is read and parsed with ``int``. Blank lines
    (such as a trailing empty line) are skipped instead of raising
    ``IndexError``.

    Args:
        path: Path of the file to read.
        lbls: List that the parsed labels are appended to, in file order.

    Returns:
        ``lbls`` itself (the same list object), for convenience.
    """
    with open(path, mode='rt') as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            fields = line.split()
            if not fields:
                continue
            lbls.append(int(fields[0]))
    return lbls
# Compare svm-predict outputs trained with and without class weights against
# the true shuttle test labels, counting misclassified instances per true class.
real_lbls = []
weight_class_lbls = []
no_weight_class_lbls = []
collect_lbl('./shuttle.te', real_lbls)
collect_lbl('./shuttle.out.weight_class', weight_class_lbls)
collect_lbl('./shuttle.out.no_weight_class', no_weight_class_lbls)

# Predictions must line up one-to-one with the test instances. Fail loudly on
# a length mismatch rather than silently ignoring extra lines or hitting an
# IndexError halfway through.
if not len(real_lbls) == len(weight_class_lbls) == len(no_weight_class_lbls):
    raise ValueError(
        'label count mismatch: '
        f'{len(real_lbls)} true, {len(weight_class_lbls)} weight_class, '
        f'{len(no_weight_class_lbls)} no_weight_class'
    )


def _errors_by_class(truth, preds):
    """Map each true label to how many of its instances were mispredicted.

    Only labels with at least one error appear. Keys are ordered by first
    error, and the result is a plain dict so it prints as ``{label: count}``.
    """
    return dict(Counter(t for t, p in zip(truth, preds) if p != t))


weight_class_err = _errors_by_class(real_lbls, weight_class_lbls)
no_weight_class_err = _errors_by_class(real_lbls, no_weight_class_lbls)
print('weight_class_err', weight_class_err)
print('no_weight_class_err', no_weight_class_err)