This post records my notes on deriving and understanding the Factorization Machine (FM) algorithm.
Paper:
https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf
Derivation of the second-order FM
In prediction tasks, FM models the interactions between different features. Taking second-order (pairwise) interactions as an example:

$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_ix_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n}W_{ij}x_ix_j \tag{1}$$
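Here $w_0$, $w_i$ and $W$ are the parameters the model has to learn. In real-world settings the input $x$ is usually a very high-dimensional, sparse vector built from one-hot encoded features, so most pairs $x_i$, $x_j$ are rarely (if ever) nonzero together in the training data. Learning every interaction weight $W_{ij}$ independently therefore overfits easily.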
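To make Eq. (1) concrete, here is a minimal pure-Python sketch that evaluates it with an explicit $n\times n$ interaction matrix `W`. The function names `predict_full_w` and `num_pair_weights` and the toy numbers are hypothetical, chosen only for illustration. It also counts the free interaction weights, one per unordered pair.

```python
from typing import List, Sequence


def predict_full_w(w0: float, w: Sequence[float],
                   W: List[List[float]], x: Sequence[float]) -> float:
    """Evaluate Eq. (1); only the upper triangle W[i][j] with j > i is read."""
    n = len(x)
    linear = sum(w[i] * x[i] for i in range(n))
    pairwise = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            pairwise += W[i][j] * x[i] * x[j]
    return w0 + linear + pairwise


def num_pair_weights(n: int) -> int:
    """Number of free interaction weights in Eq. (1): n(n-1)/2."""
    return n * (n - 1) // 2


# Hypothetical toy example with n = 3 features.
w0 = 0.5
w = [0.1, -0.2, 0.3]
W = [[0.0, 1.0, 2.0],
     [1.0, 0.0, 3.0],
     [2.0, 3.0, 0.0]]
x = [1.0, 0.0, 1.0]
print(predict_full_w(w0, w, W, x))
print(num_pair_weights(1_000_000))  # → 499999500000
```

The gradient of $\hat y$ with respect to $W_{ij}$ is $x_ix_j$, so with one-hot inputs a pair's weight is only updated when both features are active in the same sample, which is why this parameterization overfits on sparse data.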
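Note, however, that only the entries $W_{ij}$ with $i<j$ appear in (1), so $W$ can be taken to be a real symmetric matrix. By a basic property of real symmetric matrices, every real symmetric matrix $A$ can be decomposed as $A=Q\Lambda Q^T$, where $\Lambda$ is diagonal and $Q$ is orthogonal.

If $W$ is also positive semidefinite, the diagonal entries of $\Lambda$ are nonnegative and $W=VV^T$ with $V=Q\Lambda^{1/2}$. (Since the diagonal of $W$ never enters (1), it can always be chosen large enough, e.g. making $W$ diagonally dominant, for this to hold.) FM keeps this form but restricts $V\in \mathbb{R}^{n\times k}$ with a small $k$, i.e. it uses a low-rank factorization $W\approx VV^T$ whose $i$-th row $v_i$ is the latent vector of feature $i$, so that $W_{ij}=\langle v_i, v_j\rangle$. Eq. (1) then becomes: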
$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_ix_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_ix_j \tag{2}$$
Here $v_i$ and $v_j$ are vectors of length $k$, and their inner product is

$$\langle v_i, v_j \rangle = \sum_{f=1}^{k}v_{i,f} \cdot v_{j,f}$$
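A small sketch of the factorization, assuming a dense pure-Python representation (the helper names `dot`, `gram` and `predict_fm_eq2` are hypothetical): it builds $W=VV^T$ from a factor matrix `V`, so that `W[i][j]` equals $\langle v_i, v_j\rangle$, and evaluates Eq. (2) using those inner products.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """<a, b> = sum_f a_f * b_f."""
    return sum(p * q for p, q in zip(a, b))


def gram(V: Matrix) -> Matrix:
    """W = V V^T, so W[i][j] = <v_i, v_j> with v_i the i-th row of V."""
    return [[dot(vi, vj) for vj in V] for vi in V]


def predict_fm_eq2(w0: float, w: Sequence[float], V: Matrix, x: Sequence[float]) -> float:
    """Evaluate Eq. (2): pairwise weights come from inner products of latent vectors."""
    n = len(x)
    linear = sum(w[i] * x[i] for i in range(n))
    pairwise = sum(dot(V[i], V[j]) * x[i] * x[j]
                   for i in range(n) for j in range(i + 1, n))
    return w0 + linear + pairwise


# Hypothetical n = 3, k = 2 factor matrix: 6 parameters instead of a free 3 x 3 W.
V = [[1.0, 0.0],
     [0.5, 2.0],
     [-1.0, 1.0]]
W = gram(V)
print(W)  # → [[1.0, 0.5, -1.0], [0.5, 4.25, 1.5], [-1.0, 1.5, 2.0]]
```

The matrix is symmetric by construction, and the number of parameters grows as $nk$ instead of $n^2$.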
Therefore:

$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_ix_i+\sum_{i=1}^{n}\sum_{j=i+1}^{n}\sum_{f=1}^{k}v_{i,f} \cdot v_{j,f}x_ix_j \tag{3}$$
Evaluating (3) directly takes $O(kn^2)$ time, but reorganizing the computation reduces the cost to $O(kn)$.
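For reference, a direct transcription of the interaction term of Eq. (3) as a pure-Python sketch (the function name `fm_pairwise_naive` is hypothetical). The three nested loops make the $O(kn^2)$ cost explicit.

```python
from typing import List, Sequence


def fm_pairwise_naive(V: List[List[float]], x: Sequence[float]) -> float:
    """Interaction term of Eq. (3): sum_{i<j} sum_f v_{i,f} v_{j,f} x_i x_j, O(k n^2)."""
    n, k = len(V), len(V[0])
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):      # unordered pairs i < j only
            for f in range(k):         # accumulates <v_i, v_j> x_i x_j
                total += V[i][f] * V[j][f] * x[i] * x[j]
    return total


# Hypothetical toy data: n = 3, k = 2.
V = [[1.0, 0.0],
     [0.5, 2.0],
     [-1.0, 1.0]]
x = [1.0, 2.0, 3.0]
print(fm_pairwise_naive(V, x))  # → 7.0
```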
Let

$$M=\sum_{i=1}^{n}\sum_{j=i+1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j$$

which is exactly the interaction term of (3).
Also define the same sum over all ordered pairs $(i,j)$, including $i=j$:

$$N=\sum_{i=1}^{n}\sum_{j=1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j$$
Split the inner sum over $j$ into $j<i$, $j=i$ and $j>i$. Because the summand is symmetric in $i$ and $j$, the $j<i$ part summed over all $i$ equals the $j>i$ part summed over all $i$, so:

$$
\begin{aligned}
N= & \sum_{i=1}^{n}\sum_{j=1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j \\
= & \sum_{i=1}^{n}\sum_{f=1}^{k}\Big(\sum_{j=1}^{i-1} v_{i,f}v_{j,f}x_ix_j + v_{i,f}v_{i,f}x_ix_i + \sum_{j=i+1}^{n} v_{i,f}v_{j,f}x_ix_j \Big) \\
= & \sum_{i=1}^{n}\sum_{f=1}^{k}\Big(2\sum_{j=i+1}^{n}v_{i,f}v_{j,f}x_ix_j+ v_{i,f}v_{i,f}x_ix_i \Big) \\
= & 2 \sum_{i=1}^{n}\sum_{f=1}^{k}\sum_{j=i+1}^{n}v_{i,f}v_{j,f}x_ix_j + \sum_{i=1}^{n}\sum_{f=1}^{k} v_{i,f}v_{i,f}x_ix_i \\
= & 2M+ \sum_{i=1}^{n}\sum_{f=1}^{k} v_{i,f}v_{i,f}x_ix_i
\end{aligned}
\tag{4}
$$
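Identity (4) is easy to sanity-check numerically. The sketch below (pure Python, hypothetical helper name `sums_m_n_d`) computes $M$, $N$ and the diagonal term $\sum_i\sum_f v_{i,f}^2x_i^2$ by brute force on random data and compares $N$ with $2M + \text{diag}$.

```python
import random
from typing import List, Sequence, Tuple


def sums_m_n_d(V: List[List[float]], x: Sequence[float]) -> Tuple[float, float, float]:
    """Brute-force M (pairs i<j), N (all ordered pairs) and the diagonal term of Eq. (4)."""
    n, k = len(V), len(V[0])
    M = sum(V[i][f] * V[j][f] * x[i] * x[j]
            for i in range(n) for j in range(i + 1, n) for f in range(k))
    N = sum(V[i][f] * V[j][f] * x[i] * x[j]
            for i in range(n) for j in range(n) for f in range(k))
    diag = sum(V[i][f] * V[i][f] * x[i] * x[i]
               for i in range(n) for f in range(k))
    return M, N, diag


random.seed(42)
n, k = 8, 4
V = [[random.gauss(0.0, 1.0) for _ in range(k)] for _ in range(n)]
x = [random.gauss(0.0, 1.0) for _ in range(n)]
M, N, diag = sums_m_n_d(V, x)
print(abs(N - (2 * M + diag)))  # only floating-point rounding remains
```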
Therefore:

$$
\begin{aligned}
M &= \frac{1}{2}\Big(N-\sum_{i=1}^{n}\sum_{f=1}^{k} v_{i,f}v_{i,f}x_ix_i\Big) \\
&= \frac{1}{2}\sum_{i=1}^{n}\sum_{j=1}^{n}\sum_{f=1}^{k}v_{i,f}v_{j,f}x_ix_j - \frac{1}{2}\sum_{i=1}^{n}\sum_{f=1}^{k}v_{i,f}v_{i,f}x_ix_i \\
&= \frac{1}{2}\sum_{f=1}^{k}\Big(\sum_{i=1}^{n}v_{i,f}x_i\Big)\Big(\sum_{j=1}^{n}v_{j,f}x_j\Big) - \frac{1}{2}\sum_{f=1}^{k}\sum_{i=1}^{n}v_{i,f}^{2}x_i^{2} \\
&= \frac{1}{2}\sum_{f=1}^{k}\Big[\Big(\sum_{i=1}^{n}v_{i,f}x_i\Big)^2-\sum_{i=1}^{n}(v_{i,f}x_i)^2\Big]
\end{aligned}
\tag{5}
$$

The third line uses the fact that, for each fixed $f$, the double sum over $i$ and $j$ factors into the product of two identical single sums, so the sum over $f$ stays outside the square.
Hence (3) can be rewritten as:

$$\hat{y}(x)=w_0+\sum_{i=1}^{n}w_ix_i+\frac{1}{2}\sum_{f=1}^{k}\Big[\Big(\sum_{i=1}^{n}v_{i,f}x_i\Big)^2-\sum_{i=1}^{n}(v_{i,f}x_i)^2\Big] \tag{6}$$
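Eq. (6) translates into a single pass over the features per factor. The sketch below (pure Python; the names `fm_pairwise_fast`, `fm_pairwise_naive` and `fm_predict` are hypothetical) keeps two running sums per factor $f$, namely $\sum_i v_{i,f}x_i$ and $\sum_i (v_{i,f}x_i)^2$, and checks the result against the $O(kn^2)$ definition from Eq. (3) on random data.

```python
import random
from typing import List, Sequence

Matrix = List[List[float]]


def fm_pairwise_fast(V: Matrix, x: Sequence[float]) -> float:
    """Interaction term of Eq. (6): 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i (v_if x_i)^2], O(k n)."""
    n, k = len(V), len(V[0])
    total = 0.0
    for f in range(k):
        s = 0.0   # sum_i v_{i,f} x_i
        sq = 0.0  # sum_i (v_{i,f} x_i)^2
        for i in range(n):
            t = V[i][f] * x[i]
            s += t
            sq += t * t
        total += s * s - sq
    return 0.5 * total


def fm_pairwise_naive(V: Matrix, x: Sequence[float]) -> float:
    """Reference: interaction term of Eq. (3), O(k n^2)."""
    n, k = len(V), len(V[0])
    return sum(V[i][f] * V[j][f] * x[i] * x[j]
               for i in range(n) for j in range(i + 1, n) for f in range(k))


def fm_predict(w0: float, w: Sequence[float], V: Matrix, x: Sequence[float]) -> float:
    """Full second-order FM prediction, Eq. (6)."""
    return w0 + sum(wi * xi for wi, xi in zip(w, x)) + fm_pairwise_fast(V, x)


random.seed(0)
n, k = 50, 5
V = [[random.gauss(0.0, 0.1) for _ in range(k)] for _ in range(n)]
x = [random.random() for _ in range(n)]
print(abs(fm_pairwise_fast(V, x) - fm_pairwise_naive(V, x)))  # agreement up to rounding
```

The rewrite is exact rather than an approximation: the two functions differ only by floating-point rounding.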
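Evaluating Eq. (6) takes $O(kn)$ time: for each of the $k$ factors, the two inner sums over $i$ each take $O(n)$. Since $k \ll n$ and $k$ is a fixed constant, the cost is linear in $n$. For sparse $x$, only the nonzero entries contribute to these sums, so the cost is actually $O(km)$, where $m$ is the number of nonzero entries of $x$.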
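With one-hot inputs most $x_i$ are zero, so the sums in Eq. (6) only need to run over the nonzero entries. A minimal sketch, assuming $x$ is stored as a hypothetical `{feature_index: value}` dict (the name `fm_predict_sparse` is also hypothetical); its cost depends on the number of nonzero entries, not on $n$.

```python
from typing import Dict, List


def fm_predict_sparse(w0: float, w: List[float], V: List[List[float]],
                      x: Dict[int, float]) -> float:
    """Second-order FM, Eq. (6), touching only the nonzero entries of x."""
    k = len(V[0])
    linear = sum(w[i] * xi for i, xi in x.items())
    total = 0.0
    for f in range(k):
        s = 0.0   # sum over nonzero i of v_{i,f} x_i
        sq = 0.0  # sum over nonzero i of (v_{i,f} x_i)^2
        for i, xi in x.items():
            t = V[i][f] * xi
            s += t
            sq += t * t
        total += s * s - sq
    return w0 + linear + 0.5 * total


# Hypothetical n = 5 features, two of them active (e.g. one user id and one item id).
w = [0.5, 0.0, 0.0, 0.25, 0.0]
V = [[1.0, 2.0],
     [0.0, 0.0],
     [0.0, 0.0],
     [3.0, -1.0],
     [5.0, 5.0]]
x = {0: 1.0, 3: 1.0}
print(fm_predict_sparse(0.0, w, V, x))  # → 1.75 = w[0] + w[3] + <v_0, v_3> (intercept 0)
```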
Backpropagation for second-order FM
The parameters to be learned in Eq. (6) are $w_0$, $w_i$ and $v_{i,f}$. Their partial derivatives are as follows.
Derivative with respect to $w_0$:

$$\frac{\partial \hat y}{ \partial w_0} =1$$
Derivative with respect to $w_i$:

$$\frac{\partial \hat y}{\partial w_i} =x_i$$
Derivative with respect to $v_{i,f}$ (using $j$ as the summation index so it does not clash with the fixed $i$):

$$\frac{\partial \hat y}{\partial v_{i,f}} = x_i\sum_{j=1}^{n}v_{j,f}x_j - v_{i,f}x_i^2$$

The sum $\sum_{j=1}^{n}v_{j,f}x_j$ does not depend on $i$, so it can be computed once per $f$ (it is already needed when evaluating (6)); each gradient then costs $O(1)$.
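The gradients can be computed in $O(kn)$ by precomputing $s_f=\sum_j v_{j,f}x_j$ once per factor. The sketch below (pure Python; `fm_predict`, `fm_grads` and `max_fd_error` are hypothetical names) implements them and compares $\partial\hat y/\partial v_{i,f}$ with a central finite difference. Because $\hat y$ is linear in each individual $v_{i,f}$ (the $v_{i,f}^2x_i^2$ terms cancel), the finite difference should agree up to rounding error.

```python
import random
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def fm_predict(w0: float, w: Sequence[float], V: Matrix, x: Sequence[float]) -> float:
    """Second-order FM prediction, Eq. (6)."""
    n, k = len(V), len(V[0])
    out = w0 + sum(w[i] * x[i] for i in range(n))
    for f in range(k):
        s = sum(V[i][f] * x[i] for i in range(n))
        sq = sum((V[i][f] * x[i]) ** 2 for i in range(n))
        out += 0.5 * (s * s - sq)
    return out


def fm_grads(V: Matrix, x: Sequence[float]) -> Tuple[float, List[float], Matrix]:
    """Return d(y_hat)/d(w0), d(y_hat)/d(w_i) and d(y_hat)/d(v_{i,f})."""
    n, k = len(V), len(V[0])
    s = [sum(V[j][f] * x[j] for j in range(n)) for f in range(k)]  # s_f, shared by all i
    g_w0 = 1.0
    g_w = list(x)
    g_V = [[x[i] * s[f] - V[i][f] * x[i] * x[i] for f in range(k)] for i in range(n)]
    return g_w0, g_w, g_V


def max_fd_error(w0: float, w: Sequence[float], V: Matrix, x: Sequence[float],
                 h: float = 1e-5) -> float:
    """Largest gap between the analytic d(y_hat)/d(v_{i,f}) and a central difference."""
    _, _, g_V = fm_grads(V, x)
    worst = 0.0
    for i in range(len(V)):
        for f in range(len(V[0])):
            orig = V[i][f]
            V[i][f] = orig + h
            up = fm_predict(w0, w, V, x)
            V[i][f] = orig - h
            down = fm_predict(w0, w, V, x)
            V[i][f] = orig  # restore the parameter
            worst = max(worst, abs((up - down) / (2 * h) - g_V[i][f]))
    return worst


random.seed(1)
n, k = 6, 3
V = [[random.gauss(0.0, 1.0) for _ in range(k)] for _ in range(n)]
w = [random.gauss(0.0, 1.0) for _ in range(n)]
x = [random.choice([0.0, 1.0, 2.5]) for _ in range(n)]
print(max_fd_error(0.1, w, V, x))  # close to 0 (floating-point rounding only)
```

When training with a loss $\ell(\hat y, y)$, each of these gradients is multiplied by $\partial\ell/\partial\hat y$ (chain rule).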
Higher-order FM
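Let $d$ be the highest order of feature interactions (the maximum number of features that interact in a single term). Then: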
$$\hat y(x)=w_0+\sum_{i=1}^{n}w_ix_i + \sum_{l=2}^{d}\sum_{i_1=1}^n\cdots\sum_{i_l=i_{l-1}+1}^n\Big(\prod_{j=1}^{l}x_{i_j}\Big)\Big(\sum_{f=1}^{k_l}\prod_{j=1}^lv_{i_j,f}^{(l)}\Big)$$

where each order $l$ has its own factor matrix $V^{(l)}\in\mathbb{R}^{n\times k_l}$.
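A literal, brute-force evaluation of this formula, as a pure-Python sketch (the name `hofm_predict_naive` and the dict-of-factor-matrices layout `Vs = {l: V_l}` are hypothetical). `itertools.combinations` enumerates exactly the index tuples $i_1 < i_2 < \dots < i_l$, so the cost grows like $n^d$.

```python
from itertools import combinations
from math import prod
from typing import Dict, List, Sequence


def hofm_predict_naive(w0: float, w: Sequence[float],
                       Vs: Dict[int, List[List[float]]], x: Sequence[float]) -> float:
    """Order-d FM by brute force; Vs[l] is the n x k_l factor matrix V^(l), l = 2..d."""
    n = len(x)
    out = w0 + sum(w[i] * x[i] for i in range(n))
    for l, V in Vs.items():
        k_l = len(V[0])
        for idx in combinations(range(n), l):   # i_1 < i_2 < ... < i_l
            x_term = prod(x[i] for i in idx)     # prod_j x_{i_j}
            if x_term == 0.0:
                continue                         # tuples with a zero feature contribute nothing
            out += x_term * sum(prod(V[i][f] for i in idx) for f in range(k_l))
    return out


# Hypothetical d = 3 example with n = 3 features and k_2 = k_3 = 1.
Vs = {2: [[1.0], [2.0], [3.0]],
      3: [[1.0], [1.0], [2.0]]}
print(hofm_predict_naive(0.0, [0.0, 0.0, 0.0], Vs, [1.0, 1.0, 1.0]))  # → 13.0
```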
Evaluating this directly costs $O(kdn^d)$, but the same kind of rewriting as above reduces it to time linear in $n$.
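The post does not spell out the linear-time rewrite for general $d$. One standard route (swapped in here, not taken from the text above) is to notice that, for a fixed order $l$ and factor $f$, the inner sum $\sum_{i_1<\dots<i_l}\prod_{j} a_{i_j}$ with $a_i = v^{(l)}_{i,f}x_i$ is the elementary symmetric polynomial $e_l(a_1,\dots,a_n)$, which Newton's identities express through the power sums $p_t=\sum_i a_i^t$:

$$e_m=\frac{1}{m}\sum_{t=1}^{m}(-1)^{t-1}e_{m-t}\,p_t,\qquad e_0=1.$$

For $l=2$ this gives $e_2=\tfrac12(p_1^2-p_2)$, the same "square of the sum minus sum of squares" trick used in the second-order case. The sketch below (pure Python; `elementary_symmetric`, `order_l_term_linear` and `order_l_term_brute` are hypothetical names) computes each order-$l$ term in $O(k_l(nl + l^2))$ and cross-checks it against brute-force enumeration.

```python
import random
from itertools import combinations
from math import prod
from typing import Sequence


def elementary_symmetric(a: Sequence[float], l: int) -> float:
    """e_l(a) = sum over i_1 < ... < i_l of a_{i_1} * ... * a_{i_l}, via Newton's identities."""
    p = [0.0] * (l + 1)          # p[t] = sum_i a_i ** t for t = 1..l
    for ai in a:
        power = 1.0
        for t in range(1, l + 1):
            power *= ai
            p[t] += power
    e = [1.0] + [0.0] * l        # e[0] = 1
    for m in range(1, l + 1):
        acc = 0.0
        for t in range(1, m + 1):
            sign = 1.0 if t % 2 == 1 else -1.0
            acc += sign * e[m - t] * p[t]
        e[m] = acc / m
    return e[l]


def order_l_term_linear(V: Sequence[Sequence[float]], x: Sequence[float], l: int) -> float:
    """sum_f e_l(v_{1,f} x_1, ..., v_{n,f} x_n): the order-l interaction term, linear in n."""
    n, k_l = len(x), len(V[0])
    return sum(elementary_symmetric([V[i][f] * x[i] for i in range(n)], l)
               for f in range(k_l))


def order_l_term_brute(V: Sequence[Sequence[float]], x: Sequence[float], l: int) -> float:
    """Same term by enumerating all i_1 < ... < i_l (O(n^l)), for checking only."""
    k_l = len(V[0])
    return sum(prod(x[i] for i in idx) * sum(prod(V[i][f] for i in idx) for f in range(k_l))
               for idx in combinations(range(len(x)), l))


random.seed(7)
n, k = 9, 2
x = [random.uniform(-1.0, 1.0) for _ in range(n)]
for l in (2, 3, 4):
    V = [[random.uniform(-1.0, 1.0) for _ in range(k)] for _ in range(n)]
    print(l, abs(order_l_term_linear(V, x, l) - order_l_term_brute(V, x, l)))
```

Newton's identities can lose precision through cancellation for large $l$; the recurrence $e_m \leftarrow e_m + a_i\,e_{m-1}$ (updating $m$ from $l$ down to $1$ for each $a_i$) is also $O(nl)$ and avoids that.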