使用可视化方法解决统计学习方法中支持向量机的例 7.1

例 7.1 已知一个如图 7.4 所示的训练数据集,其正例点是 x 1 = ( 3 , 3 ) T x_1=(3,3)^\mathsf{T} x1=(3,3)T x 2 = ( 4 , 3 ) T x_2=(4,3)^\mathsf{T} x2=(4,3)T x 1 = ( 1 , 1 ) T x_1=(1,1)^\mathsf{T} x1=(1,1)T,试求最大间隔分离超平面。

 按照算法 7.1,根据训练数据集构造约束最优化问题:

min ⁡ 1 2 ( w 1 2 + w 2 2 ) \min \dfrac{1}{2}(w_1^2+w_2^2) min21(w12+w22)

s . t .   3 w 1 + 3 w 2 + b ⩾ 1 ( 1 )   4 w 1 + 3 w 2 + b ⩾ 1 ( 2 ) − w 1 − w 2 − b ⩾ 1 ( 3 ) \begin{aligned} s.t.\quad &~3w_1+3w_2+b\geqslant 1\qquad(1)\\ &~4w_1+3w_2+b\geqslant 1\qquad(2)\\ &-w_1-w_2-b\geqslant 1\qquad(3)\\ \end{aligned} s.t. 3w1+3w2+b1(1) 4w1+3w2+b1(2)w1w2b1(3)

  式 ( 1 ) (1) (1)加上式 ( 3 ) (3) (3)得:
w 1 + w 2 ⩾ 1 w_1+w_2\geqslant 1 w1+w21
  式 ( 1 ) (1) (1)减去式 ( 2 ) (2) (2)得:
w 1 ⩾ 1 w_1\geqslant 1 w11
  式 ( 3 ) (3) (3)乘3加上式 ( 2 ) (2) (2)得:
b ⩽ − 2 b\leqslant -2 b2

  因此可以绘制二维图如下:

  可以看出阴影部分即为最优解 w ∗ w^* w所在的区域,并且 w 1 + w 2 = 1 w_1+w_2=1 w1+w2=1与法向量的交点即为最优解 w ∗ w^* w(垂线距离最短)。

  得出 w ∗ = ( 1 2 , 1 2 ) , b = − 2 w*=(\dfrac{1}{2},\dfrac{1}{2}),b=-2 w=(21,21)b=2。于是最大间隔分离超平面为:

1 2 x ( 1 ) + 1 2 x ( 2 ) − 2 = 0 \dfrac{1}{2}x^{(1)}+\dfrac{1}{2}x^{(2)}-2=0 21x(1)+21x(2)2=0

其中, x 1 = ( 3 , 3 ) T x_1=(3,3)^\mathsf{T} x1=(3,3)T x 3 = ( 1 , 1 ) T x_3=(1,1)^\mathsf{T} x3=(1,1)T为支持向量。

可视化的代码如下:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import mpl_toolkits.axisartist as axisartist
plt.rcParams['axes.unicode_minus']=False #解决负号无法显示的问题

# 创建画布
fig=plt.figure(figsize=(8,8),dpi=300)
# 使用axisartist.Subplot方法创建一个绘图区对象ax
ax=axisartist.Subplot(fig,111)
# 将绘图区对象添加到画布中
fig.add_axes(ax)
# 通过set_visible方法设置绘图区所有坐标轴隐藏
ax.axis[:].set_visible(False)
# ax.new_floating_axis代表添加新的坐标轴
ax.axis["x"]=ax.new_floating_axis(0,0)
# 给x坐标轴加上箭头
ax.axis["x"].set_axisline_style("->",size=1.0)
# 添加y坐标轴,且加上箭头
ax.axis["y"]=ax.new_floating_axis(1,0)
ax.axis["y"].set_axisline_style("->",size=1.0)
# 设置x、y轴上刻度显示方向
ax.axis["x"].set_axis_direction("top")
ax.axis["y"].set_axis_direction("right")

# 添加同心圆
cir1=Circle(xy=(0.0,0.0),radius=2,alpha=0.5,facecolor='lightskyblue',ec='black',ls='--')
ax.add_patch(cir1)
cir2=Circle(xy=(0.0,0.0),radius=1.75,alpha=0.5,ec='black',ls='--')
ax.add_patch(cir2)
cir3=Circle(xy=(0.0,0.0),radius=3/2,alpha=0.5,ec='black',ls='--')
ax.add_patch(cir3)
cir4=Circle(xy=(0.0,0.0),radius=1.25,alpha=0.5,ec='black',ls='--')
ax.add_patch(cir4)
cir5=Circle(xy=(0.0,0.0),radius=1,alpha=1,ec='red',lw=3)
ax.add_patch(cir5)
cir6=Circle(xy=(0.0,0.0),radius=0.75,alpha=0.15,facecolor='blue',ec='black',ls='--')
ax.add_patch(cir6)
cir7=Circle(xy=(0.0,0.0),radius=1/2,alpha=0.25,facecolor='blue',ec='black',ls='--')
ax.add_patch(cir7)
cir8=Circle(xy=(0.0,0.0),radius=0.25,alpha=0.5,facecolor='blue',ec='black',ls='--')
ax.add_patch(cir8)

## 添加文本
# 设置x、y坐标轴的标签
plt.text(2.3,-0.25,r"$\omega_1$",fontsize=20)
plt.text(-0.35,2.3,r"$\omega_2$",fontsize=20)
# 线
plt.text(-2,2,r"$\omega_1+\omega_2=1$",fontsize=15)
# 最优解
plt.text(0.9,0.6,r"$\omega^*=(0.5,0.5)$",fontsize=15)
# 点
plt.text(-0.45,-0.2,r"$(0,0)$",fontsize=15)

# 设置x、y坐标轴的范围
plt.xlim(-2,2.5)
plt.ylim(-2,2.5)

# 添加直线
x=np.arange(-2.5,3,0.1)
y=1-x
plt.plot(x,y,c='r',lw=3)
y=np.zeros(x.shape[0])
plt.plot(x,y,c='r',lw=3)
# 添加线段
x_line,y_line =[[0,1/2]],[[0,1/2]]
for i in range(len(x_line)):
    ax.plot(x_line[i],y_line[i],'y')
def get_y1(x):
    return np.sqrt(1-x**2)
def get_y2(x):
    return 1-x
for i in np.arange(0.1,1,0.075):
    ax.plot([i,i],[get_y1(i)-0.025,get_y2(i)+0.025],'black')
    
# 添加点
x0,y0=0,0
ax.plot(x0,y0,c='black',marker='o')
x1,y1=1/2,1/2
ax.plot(x1,y1,c='black',marker='o')

# 保存图像
fig.savefig('7_1.png')
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值