支持向量机学习

支持向量机学习资料:

李航的统计学习方法

支持向量机学习博客一

支持向量机学习博客二

核函数学习链接如下:

核函数学习链接

核函数学习链接2

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# SMO(序列最小优化)算法的一个简单实现
# A simple implementation of SMO (Sequential Minimal Optimization)

import sys
import math
import matplotlib.pyplot as plt

# Training inputs: one (x1, x2) tuple per sample, appended by loaddata().
samples = list()
# Class label of each sample, index-aligned with `samples`.
labels = list()


class svm_params:
    """Plain holder for the model parameters of the support vector machine."""

    def __init__(self):
        # Scalar offset term, starts at zero.
        self.b = 0
        # One multiplier per training sample; loaddata() appends the entries.
        self.a = list()


# The single, module-wide parameter set of the support vector machine.
params = svm_params()
# Cached predict_real_diff(i) value for every sample i (see init_e_dict()).
e_dict = list()
# Alternative data set, kept for reference:
#train_data = "svm.train_mix_ok"
# Path of the training data file.
train_data = "svm.train"

def loaddata():
    """Load the training file and initialise the per-sample SVM state.

    Reads the file named by the module-level ``train_data``. Every non-blank
    line must contain three tab-separated integers: the two feature values
    followed by the label. For each such line the feature pair is appended to
    ``samples``, the label to ``labels`` and an initial multiplier of ``0.0``
    to ``params.a``.

    :return: None
    :raises ValueError: if a field cannot be parsed as an integer
    :raises IndexError: if a line has fewer than three fields
    """
    # The context manager closes the file even when a malformed line raises.
    with open(train_data, "r") as fn:
        for line in fn:
            if not line.strip():
                # Tolerate blank lines, e.g. an empty line at the end of file.
                continue
            # Remove only the line terminator. Slicing off the last character
            # unconditionally would eat a digit of the label when the final
            # line has no trailing newline.
            vlist = line.rstrip("\r\n").split("\t")
            # Parse the whole record before touching any container so that a
            # bad line cannot leave samples/labels/params.a with different
            # lengths.
            x1 = int(vlist[0])
            x2 = int(vlist[1])
            label = int(vlist[2])
            samples.append((x1, x2))
            labels.append(label)
            params.a.append(0.0)

# Linear kernel: the inner product of two training samples.
def kernel(j, i):
    """
    Evaluate the linear kernel for two training samples.
    :param j: index into ``samples`` of the first point
    :param i: index into ``samples`` of the second point
    :return: sum of the element-wise products of samples[j] and samples[i],
             taken over the components of samples[j], as a float
    """
    dot = 0.0
    for pos, component in enumerate(samples[j]):
        dot += component * samples[i][pos]
    return dot

def predict_real_diff(i):
    """
    Difference between the current model output for sample i and its label.
    :param i: index of the sample to evaluate
    :return: sum over j of params.a[j] * labels[j] * kernel(j, i),
             plus params.b, minus labels[i]
    """
    weighted_sum = 0.0
    for idx, _ in enumerate(samples):
        weighted_sum += params.a[idx] * labels[idx] * kernel(idx, i)
    return weighted_sum + params.b - labels[i]

def init_e_dict():
    """
    (Re)build ``e_dict`` so that e_dict[i] == predict_real_diff(i) for every
    index i of ``params.a``.

    The list is replaced in place rather than appended to, so calling this
    more than once (for example by running train() twice) cannot leave
    duplicate or stale entries behind. Slice assignment keeps the identity of
    the module-level list object.
    :return: None
    """
    e_dict[:] = [predict_real_diff(i) for i in range(len(params.a))]

def update_e_dict():
    """
    Recompute every cached entry of ``e_dict`` from the current parameters.
    Overwrites e_dict[k] with predict_real_diff(k) for each index k of
    ``params.a``; the list must already hold at least that many entries.
    :return: None
    """
    for slot, _ in enumerate(params.a):
        e_dict[slot] = predict_real_diff(slot)

def train(tolerance, times, C):
    """
    Optimise ``params.a`` and ``params.b`` with a simple SMO-style loop.

    Repeatedly sweeps over all samples. For every sample i that violates the
    KKT conditions (within ``tolerance``) each other sample j is tried as the
    second variable; the pair (a_i, a_j) is updated analytically, the offset
    b is recomputed and the cached errors in ``e_dict`` are refreshed.
    Progress messages are printed and also written to ``log.txt``.

    :param tolerance: slack used when testing the KKT conditions
    :param times: maximum number of sweeps over the samples
    :param C: penalty coefficient; multipliers are clipped to [0, C]
    :return: None (results are stored in ``params`` and ``e_dict``)
    """
    # The context manager guarantees log.txt is flushed and closed even if an
    # iteration raises; previously the handle was never closed.
    with open("log.txt", "w", encoding="utf8") as log:
        sweep = 0
        init_e_dict()  # cache the initial error of every sample
        updated = True
        count = len(params.a)
        # Stop after `times` sweeps or as soon as a sweep changes nothing.
        while sweep < times and updated:
            updated = False
            sweep += 1
            for i in range(count):
                ai = params.a[i]
                Ei = e_dict[i]
                yE = labels[i] * Ei
                # Only samples that violate the KKT conditions are optimised.
                violates_kkt = (yE < -tolerance and ai < C) or (yE > tolerance and ai > 0)
                if not violates_kkt:
                    continue
                k_ii = kernel(i, i)
                for j in range(count):
                    if j == i:
                        continue
                    k_jj = kernel(j, j)
                    # kernel() is symmetric, so this value serves as both
                    # K(i, j) and K(j, i) below.
                    k_ij = kernel(i, j)
                    eta = k_ii + k_jj - 2 * k_ij
                    if eta <= 0:
                        # No positive curvature along this pair; skip it.
                        continue
                    # Read the multipliers fresh: a_i may have been changed by
                    # an earlier j in this same inner loop.
                    old_ai = params.a[i]
                    old_aj = params.a[j]
                    # Unclipped optimum for a_j.
                    new_aj = old_aj + labels[j] * (e_dict[i] - e_dict[j]) / eta
                    # Feasible interval [low, high] for a_j; it depends on
                    # whether the two labels agree.
                    if labels[i] == labels[j]:
                        low = max(0, old_aj + old_ai - C)
                        high = min(C, old_aj + old_ai)
                    else:
                        low = max(0, old_aj - old_ai)
                        high = min(C, C + old_aj - old_ai)
                    if new_aj > high:
                        new_aj = high
                    if new_aj < low:
                        new_aj = low
                    # Skip pairs whose second variable barely moves.
                    if abs(old_aj - new_aj) < 0.001:
                        print("j = %d, is not moving enough" % j)
                        log.write("j = %d, is not moving enough" % j+"\n")
                        continue
                    # "Statistical Learning Methods" formula 7.109: update a_i
                    # so that the equality constraint stays satisfied.
                    new_ai = old_ai + labels[i] * labels[j] * (old_aj - new_aj)
                    delta_i = new_ai - old_ai
                    delta_j = new_aj - old_aj
                    # formula 7.115
                    new_b1 = params.b - e_dict[i] - labels[i]*k_ii*delta_i - labels[j]*k_ij*delta_j
                    # formula 7.116
                    new_b2 = params.b - e_dict[j] - labels[i]*k_ij*delta_i - labels[j]*k_jj*delta_j
                    # Prefer the offset derived from a multiplier strictly
                    # inside (0, C); otherwise take the midpoint.
                    if new_ai > 0 and new_ai < C:
                        new_b = new_b1
                    elif new_aj > 0 and new_aj < C:
                        new_b = new_b2
                    else:
                        new_b = (new_b1 + new_b2) / 2.0

                    params.a[i] = new_ai
                    params.a[j] = new_aj
                    params.b = new_b
                    update_e_dict()
                    updated = True
                    print("iterate: %d, changepair: i: %d, j:%d" %(sweep, i, j))
                    log.write("iterate: %d, changepair: i: %d, j:%d" %(sweep, i, j)+"\n")
                    # Flush so the log can be followed while training runs.
                    log.flush()

def draw(tolerance, C):
    """
    Plot the training points together with the learned separating line and
    its two margin lines, then show the figure.

    Points are re-read from the file named by ``train_data``; samples with a
    positive label are drawn red, the others green. The line is derived from
    ``params.a``, ``params.b``, ``labels`` and ``samples``.

    :param tolerance: value shown in the plot title
    :param C: value shown in the plot title
    :return: None
    """
    plt.xlabel(u"x1")
    plt.xlim(0, 100)
    plt.ylabel(u"x2")
    plt.ylim(0, 100)
    plt.title("SVM - %s, tolerance %f, C %f" % (train_data, tolerance, C))
    # The context manager closes the file; it was previously left open.
    with open(train_data, "r") as ftrain:
        for line in ftrain:
            if not line.strip():
                # Tolerate blank lines, e.g. an empty line at the end of file.
                continue
            # Strip only the terminator so a last line without a newline
            # keeps its final character.
            sam = line.rstrip("\r\n").split("\t")
            # Convert to numbers: string arguments would be plotted as
            # categorical values instead of positions on the 0-100 axes.
            x1 = float(sam[0])
            x2 = float(sam[1])
            marker = 'or' if int(sam[2]) > 0 else 'og'
            plt.plot(x1, x2, marker)

    # Weight vector w = sum_i a_i * y_i * x_i.
    w1 = 0.0
    w2 = 0.0
    for i in range(len(labels)):
        w1 += params.a[i] * labels[i] * samples[i][0]
        w2 += params.a[i] * labels[i] * samples[i][1]

    if w2 != 0:
        # w1*x1 + w2*x2 + b = 0  ->  x2 = slope*x1 + intercept; the margin
        # lines (w.x + b = +/-1) are shifted by 1/w2 along x2.
        slope = - w1 / w2
        intercept = - params.b / w2
        r = 1 / w2

        lp_x1 = [10, 90]
        lp_x2 = [slope * x1 + intercept for x1 in lp_x1]
        lp_x2up = [x2 + r for x2 in lp_x2]
        lp_x2down = [x2 - r for x2 in lp_x2]
        plt.plot(lp_x1, lp_x2, 'b')
        plt.plot(lp_x1, lp_x2up, 'b--')
        plt.plot(lp_x1, lp_x2down, 'b--')
    elif w1 != 0:
        # Vertical boundary: w1*x1 + b = 0, margins shifted by 1/w1 along x1.
        x_mid = - params.b / w1
        r = 1 / w1
        lp_x2 = [0, 100]
        plt.plot([x_mid, x_mid], lp_x2, 'b')
        plt.plot([x_mid + r, x_mid + r], lp_x2, 'b--')
        plt.plot([x_mid - r, x_mid - r], lp_x2, 'b--')
    # Otherwise w is the zero vector (e.g. nothing was trained): there is no
    # boundary to draw, so only the points are shown instead of dividing by 0.
    plt.show()

if __name__ == "__main__":
    # Read the training file and echo what was loaded.
    loaddata()
    print(samples)
    print(labels)
    # Penalty coefficient for misclassification.
    C = 10
    # Tolerance used when checking the KKT conditions.
    tolerance = 0.0001
    # Run at most 100 sweeps of the optimiser.
    train(tolerance, 100, C)
    print("a = ", params.a)
    print("b = ", params.b)
    # Samples whose multiplier lies strictly between 0 and C.
    support = [
        samples[idx]
        for idx, alpha in enumerate(params.a)
        if 0 < alpha < C
    ]
    print("support vector = ", support)
    # Plot the data and the learned boundary.
    draw(tolerance, C)

代码的公式网址如下:

代码理论网址

其中 params.a 表示拉格朗日乘子 [α1, α2, ...],params.b 表示偏置 b,e_dict[i] 表示误差 E_i,与代码理论网址中的公式一一对应。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值