最优运输问题

 

 

def compute_optimal_transport(M, r, c, lam, epsilon=1e-8):
    """
    Computes the optimal transport matrix and Slinkhorn distance using the
    Sinkhorn-Knopp algorithm
    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization
        - epsilon : convergence parameter
    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape
    P = np.exp(- lam * M)
    P /= P.sum()
    u = np.zeros(n)
    # normalize this matrix
    while np.max(np.abs(u - P.sum(1))) > epsilon:
        u = P.sum(1)
        P *= (r / u).reshape((-1, 1))#行归r化,注意python中*号含义
        P *= (c / P.sum(0)).reshape((1, -1))#列归c化
    return P, np.sum(P * M)

设M:

[[6.74620535e-01 6.51487856e-01 7.63909999e-01 1.22160802e-02  9.84285854e-01]
 
 [5.21836427e-02 6.98448351e-01 4.21872002e-04 5.77616315e-01  9.98398433e-01]
 
 [4.81595322e-01 8.59043865e-01 8.91100944e-01 1.27449590e-01  7.85357602e-01]
 
 [1.40637778e-01 5.98949422e-02 5.23676192e-02 1.44150411e-02  4.74618963e-01]
 
 [7.16849610e-01 2.82412228e-01 8.81465978e-01 2.55082618e-01  5.39586731e-01]
 
 [4.49385127e-01 7.78590147e-01 1.31048710e-03 8.68770877e-02  6.10843349e-01]
 
 [1.78421067e-02 7.53684632e-01 4.42902867e-01 7.38736941e-01  9.92555963e-01]
 
 [8.16664868e-01 3.12881863e-01 5.54218820e-01 6.13135979e-01  8.86964971e-01]]

c:

[0.2 0.1 0.3 0.2 0.2]#[4,2,6,4,4]/sum([4,2,6,4,4])

r:

[0.15 0.15 0.15 0.2  0.1  0.1  0.1  0.05]#[3,3,3,4,2,2,2,1]/sum([3,3,3,4,2,2,2,1])

lam:5

则经运算输出为:

P:

[[0.01079512 0.01032116 0.01212201 0.09937232 0.01738938]
 
 [0.04414505 0.00148527 0.10035015 0.00107044 0.0029491 ]
 
 [0.03008735 0.0038818  0.00681382 0.05929886 0.04991816]
 
 [0.02832576 0.03612279 0.07728489 0.0178614  0.04040515]
 
 [0.00322536 0.02411107 0.00248517 0.01088774 0.05929066]
 
 [0.00433108 0.00071122 0.07141918 0.00890056 0.01463795]
 
 [0.07703699 0.00165618 0.01614092 0.00070298 0.00446294]
 
 [0.00205329 0.02171051 0.01338386 0.00190569 0.01094666]]
P col sum:
[0.2 0.1 0.3 0.2 0.2]

P row sum:

[0.14999999 0.15000001 0.15       0.2        0.1        0.1     0.1        0.05      ]

distance:

0.24870527608715712
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值