注:本文只是本人阅读西瓜书及南瓜书的阅读笔记和心得,可能只有自己能看懂,鉴于本人水平有限,有极大可能出现错误,欢迎读者批评指正
1、预测函数
写成向量的形式为:
其中 w 为参数向量 b 为偏置量
2、线性回归
线性回归的思想就是通过给定的数据集拟合出一条最符合该数据集变化趋势的“直线”
线性回归的任务就是求出预测函数的w和b,使其损失最小
(下图是西瓜书部分公式推导,......我也不一定看得懂我写的......字太丑了)
将偏置b和参数w写成一个向量,则线性回归的任务就是求向量
的最小值,同样使用最小二乘法求损失函数的最小值
将上式对求导,就可以计算出最优解
3、 广义线性模型
上式依旧是线性模型,函数是单调可微的
4、对数几率回归
对广义线性模型中时,那么该模型就是对数线性回归
对于二分类任务,其输出值只有两个结果,即yes or no,而对于前文所说的线性回归来说,其输出值是一个实值,所以我们要将输出值转化成0/1,对于二分类任务最理想的
就是单位阶跃函数:
因为单位阶跃函数不连续,所以能作为广义线性模型中的,因此,这里用Sigmoid函数作为单位阶跃函数的代替。
将上式变形得:
y可以作为正样本,1-y可以作为负样本,那么其比值称为“几率”(odds),取对数就称作“对数几率”(log odds,亦称 logit)
勘误:上图(3.24) (3.24)
对于后验概率,我们可以通过极大似然法来求出参数
取对数得:
(3.25)
为方便计算,下面将写成
(可参考上文)
对于(3.23)以及(3.24)进行整合
(3.26)
即当y=0时,上式等于(3.24),当y=1时,上式等于(3.23)
将(3.26)带入(3.25)得(这里就不推导了):
上式是极大似然估计得似然函数,所以最大化似然函数等价于最小化似然函数的相反数
即: (3.27)
对上式求得最优解。这里就不赘述。
5、线性判别分析(Linear Discriminant Analysis)
上图为西瓜书上线性判别原图3.3
LDA的思想很朴素:对于给定的训练集,设法将样例投影到一条直线上,使得同类样例的投影点尽可能的靠近、不同类型样例的投影点尽可能的远离。
对于上式,分子分母都是关于w的二项式,因此其最值至于w方向有关,与其长度无关。
6、多分类学习
(后续补充。。。欢迎讨论)
import torch
lr=0.01
# 1.准备数据
# y=3x+0.8
x=torch.rand([500,1])
y_ture=x*0.3+0.8
# 2.通过模型计算y_predict
w=torch.rand([1,1],requires_grad=True)
b=torch.tensor(0.,requires_grad=True)
# 4. 通过循环,反向传播,更新参数
for i in range(500):
y_predict=torch.matmul(x,w)+b
# 3.计算loss
loss = (y_ture - y_predict).pow(2).mean()
if w.grad is not None:
w.grad.zero_() # 原地修改
if b.grad is not None:
b.grad.zero_()
loss.backward()
w.data=w.data-lr*w.grad
b.data=b.data-lr*b.grad
if i%50==0:
print(f"w={w.item()}, b={b.item()},loss={loss}")
w=0.44416868686676025, b=0.0146456528455019,loss=0.5378040671348572
w=0.6275120973587036, b=0.4291453957557678,loss=0.05473712459206581
w=0.66026771068573, b=0.5544319152832031,loss=0.01557918917387724
w=0.652208149433136, b=0.5984543561935425,loss=0.01113317720592022
w=0.6338667869567871, b=0.6192365884780884,loss=0.009559470228850842
w=0.6136873364448547, b=0.632994532585144,loss=0.008361082524061203
w=0.5939584970474243, b=0.6442849636077881,loss=0.0073250518180429935
w=0.5752568244934082, b=0.6544107794761658,loss=0.00641834270209074
w=0.5576854944229126, b=0.6637656688690186,loss=0.005623943638056517
w=0.5412192940711975, b=0.6724882125854492,loss=0.0049278754740953445Process finished with exit code 0