第三章:线性回归
基本形式
表达式:
f
(
x
)
=
w
1
x
1
+
w
2
x
2
+
.
.
.
+
w
n
x
n
+
b
f(x)=w_1x_1+w_2x_2+...+w_nx_n+b
f(x)=w1x1+w2x2+...+wnxn+b
向量形式
f
(
x
)
=
w
T
+
b
f(x)=w^T+b
f(x)=wT+b
其中
w
=
(
w
1
;
w
2
;
.
.
.
;
w
d
)
w=(w_1;w_2;...;w_d)
w=(w1;w2;...;wd)
w
w
w 和
b
b
b 学得之后模型确定。
线性回归
均方误差(最小二乘法):
( w ∗ , b ∗ ) = arg max ( w , b ) ∑ i = 1 m ( f ( x i ) − y i ) 2 = arg max ( w , b ) ∑ i = 1 m ( y i − w x i − b ) 2 (w^*,b^*)=\mathop{\arg\max}\limits_{(w,b)}\sum_{i=1}^m(f(x_i)-y_i)^2\\=\mathop{\arg\max}\limits_{(w,b)}\sum_{i=1}^m(y_i-wx_i-b)^2 (w∗,b∗)=(w,b)argmaxi=1∑m(f(xi)−yi)2=(w,b)argmaxi=1∑m(yi−wxi−b)2
几何意义:
找一条直线,使所有样本到直线上的欧式距离之和最小
torch实现线性回归
import torch
X = torch.tensor([[1,0,0],[1,1,0],[1,0,1],[1,1,1]], dtype = torch.float32)
z = torch.tensor([-0.2, -0.05, -0.05, 0.1])
w = torch.tensor([-0.2,0.15,0.15])
def LinearR(X,w):
zhat = torch.mv(X,w)
return zhat
zhat = LinearR(X,w)
对数几率回归
sigmoid函数
由于单位阶跃函数不连续,需要寻找在一定程度上替代单位阶跃函数的“替代函数”
σ
=
S
i
g
m
o
i
d
(
z
)
=
1
1
+
e
−
z
=
1
1
+
e
−
(
w
T
x
+
b
)
\sigma = Sigmoid(z)=\frac{1}{1+e^{-z}}\\=\frac{1}{1+e^{-(w^Tx+b)}}
σ=Sigmoid(z)=1+e−z1=1+e−(wTx+b)1
torch实现线性回归
X = torch.tensor([[1,0,0],[1,1,0],[1,0,1],[1,1,1]], dtype = torch.float32)
andgate = torch.tensor([[0],[0],[0],[1]], dtype = torch.float32)
w = torch.tensor([-0.2,0.15,0.15], dtype = torch.float32)
def LogisticR(X,w):
zhat = torch.mv(X,w)
sigma = torch.sigmoid(zhat) # 调用torch的sigmoid
#sigma = 1/(1+torch.exp(-zhat)) # 自己写sigmoid
andhat = torch.tensor([int(x) for x in sigma >= 0.5], dtype = torch.float32)
return sigma, andhat
sigma, andhat = LogisticR(X,w)
线性判别分析(LDA)
LDA思想
将样本投影到一条直线上,使得异类样本的中心尽可能远,同类样本的方差尽可能小
多分类学习
OVO(一对一)
将N个类别进行两两配对作,做为正反例训练一个分类器,最终会有 N ( N − 1 ) 2 \frac{N(N-1)}{2} 2N(N−1) 个分类器。对其最终结果可通过投票,将被预测的最多的类别作为最终结果
OVR(一对其余)
将一个类别作为正例其他类别做为反例训练N个分类器,在测试时,若只有一个分类器预测为正类则对应类别为最终结果,若有多个分类器为正类则考虑预测置信度,选择置信度最大的结果作为分类结果。
类别不平衡
欠采样
去除一些数量多的样本(若随机丢弃可能会丢失重要信息)
过采样
增加数量少的样本(不能简单对初始样本进行重复采样,会导致过拟合)
直接训练
在预测时调整阈值
代表性算法
欠采样
EasyEnsemble
过采样
SMOTE
参考
周志华,机器学习,清华大学出版社,2016
https://www.bilibili.com/video/BV1Mh411e7VU?p=6&spm_id_from=333.1007.top_right_bar_window_history.content.click&vd_source=ae6a9270751fdffac8724e71e288e0ec
《机器学习公式详解》