def compute_optimal_transport(M, r, c, lam, epsilon=1e-8, max_iter=None):
    """
    Compute the entropy-regularized optimal transport matrix and the
    Sinkhorn distance using the Sinkhorn-Knopp algorithm.

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of row marginals (n, )
        - c : vector of column marginals (m, )
        - lam : strength of the entropic regularization; the initial kernel
          is exp(-lam * M), so larger values concentrate mass on low-cost
          entries
        - epsilon : convergence tolerance; iteration stops once the row sums
          change by at most epsilon between two consecutive sweeps
        - max_iter : optional upper bound on the number of row/column
          sweeps; None (the default) iterates until epsilon is met
    Outputs:
        - P : optimal transport matrix (n x m); the last step of every sweep
          rescales the columns, so P's column sums equal c
        - dist : Sinkhorn distance, sum(P * M)

    r and c should carry the same total mass (e.g. both sum to 1);
    otherwise the returned P matches c but cannot match r.
    """
    import numpy as np

    M = np.asarray(M, dtype=float)
    r = np.asarray(r, dtype=float)
    c = np.asarray(c, dtype=float)
    n, _ = M.shape  # also enforces that M is 2-D

    # Shift costs so the smallest becomes 0 before exponentiating. The
    # global normalization right after cancels the constant factor
    # exp(lam * min(M)), so the plan is mathematically unchanged, but the
    # largest kernel entry is now exactly 1 and a large lam * M can no
    # longer underflow every entry to 0 (which made the whole result NaN).
    # Rows/columns whose costs all sit far above the minimum can still
    # underflow; a log-domain implementation would be needed for that.
    P = np.exp(-lam * (M - M.min()))
    P /= P.sum()

    u = np.zeros(n)
    sweeps = 0
    # Alternate row and column scaling until the row sums stop changing.
    while np.max(np.abs(u - P.sum(1))) > epsilon:
        if max_iter is not None and sweeps >= max_iter:
            break
        u = P.sum(1)
        # Scale each row so that it sums to r (`*=` is elementwise here,
        # broadcasting the column vector across each row).
        P *= (r / u).reshape((-1, 1))
        # Scale each column so that it sums to c.
        P *= (c / P.sum(0)).reshape((1, -1))
        sweeps += 1
    return P, np.sum(P * M)
设代价矩阵 M (8 x 5):
[[6.74620535e-01 6.51487856e-01 7.63909999e-01 1.22160802e-02 9.84285854e-01]
[5.21836427e-02 6.98448351e-01 4.21872002e-04 5.77616315e-01 9.98398433e-01]
[4.81595322e-01 8.59043865e-01 8.91100944e-01 1.27449590e-01 7.85357602e-01]
[1.40637778e-01 5.98949422e-02 5.23676192e-02 1.44150411e-02 4.74618963e-01]
[7.16849610e-01 2.82412228e-01 8.81465978e-01 2.55082618e-01 5.39586731e-01]
[4.49385127e-01 7.78590147e-01 1.31048710e-03 8.68770877e-02 6.10843349e-01]
[1.78421067e-02 7.53684632e-01 4.42902867e-01 7.38736941e-01 9.92555963e-01]
[8.16664868e-01 3.12881863e-01 5.54218820e-01 6.13135979e-01 8.86964971e-01]]
c:
[0.2 0.1 0.3 0.2 0.2]#[4,2,6,4,4]/sum([4,2,6,4,4])
r:
[0.15 0.15 0.15 0.2 0.1 0.1 0.1 0.05]#[3,3,3,4,2,2,2,1]/sum([3,3,3,4,2,2,2,1])
lam: 5
则调用 compute_optimal_transport(M, r, c, lam) 的输出为:
P:
[[0.01079512 0.01032116 0.01212201 0.09937232 0.01738938]
[0.04414505 0.00148527 0.10035015 0.00107044 0.0029491 ]
[0.03008735 0.0038818 0.00681382 0.05929886 0.04991816]
[0.02832576 0.03612279 0.07728489 0.0178614 0.04040515]
[0.00322536 0.02411107 0.00248517 0.01088774 0.05929066]
[0.00433108 0.00071122 0.07141918 0.00890056 0.01463795]
[0.07703699 0.00165618 0.01614092 0.00070298 0.00446294]
[0.00205329 0.02171051 0.01338386 0.00190569 0.01094666]]
P col sum:
[0.2 0.1 0.3 0.2 0.2]
P row sum:
[0.14999999 0.15000001 0.15 0.2 0.1 0.1 0.1 0.05 ]
distance:
0.24870527608715712