图像质量评价指标: MMD ( maximum-mean-discrepancy) 最大平均差异

12 篇文章 1 订阅
12 篇文章 2 订阅

MMD:maximum mean discrepancy。最大平均差异, 用于判断两个分布p和q是否相同。它的基本假设是:如果对于所有以分布生成的样本空间为输入的函数f,如果两个分布生成的足够多的样本在f上的对应的像的均值都相等,那么那么可以认为这两个分布是同一个分布。现在一般用于度量两个分布之间的相似性

Keras 2.2.4
tensorflow 1.9.0

import torch
import matplotlib
import os
import argparse
import numpy as np
from PIL import Image
from torch.autograd import Variable
from keras.applications.inception_v3 import InceptionV3

os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'   # 只显示 Error

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    '''
    将源域数据和目标域数据转化为核矩阵,即上文中的K
    Params:
        source: 源域数据(n * len(x))
        target: 目标域数据(m * len(y))
        kernel_mul:
        kernel_num: 取不同高斯核的数量
        fix_sigma: 不同高斯核的sigma值
    Return:
        sum(kernel_val): 多个核矩阵之和
    '''
    n_samples = int(source.size()[0])+int(target.size()[0])# 求矩阵的行数,一般source和target的尺度是一样的,这样便于计算
    total = torch.cat([source, target], dim=0)#将source,target按列方向合并
    #将total复制(n+m)份
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    #将total的每一行都复制成(n+m)行,即每个数据都扩展成(n+m)份
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    #求任意两个数据之间的和,得到的矩阵中坐标(i,j)代表total中第i行数据和第j行数据之间的l2 distance(i==j时为0)
    L2_distance = ((total0-total1)**2).sum(2)
    #调整高斯核函数的sigma值
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
    #以fix_sigma为中值,以kernel_mul为倍数取kernel_num个bandwidth值(比如fix_sigma为1时,得到[0.25,0.5,1,2,4]
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    #高斯核函数的数学表达式
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    #得到最终的核矩阵
    return sum(kernel_val)#/len(kernel_val)

def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    '''
    计算源域数据和目标域数据的MMD距离
    Params:
        source: 源域数据(n * len(x))
        target: 目标域数据(m * len(y))
        kernel_mul:
        kernel_num: 取不同高斯核的数量
        fix_sigma: 不同高斯核的sigma值
    Return:
        loss: MMD loss
    '''
    batch_size = int(source.size()[0]) #一般默认为源域和目标域的batchsize相同
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    #根据式(3)将核矩阵分成4部分
    XX = kernels[:batch_size, :batch_size]
    YY = kernels[batch_size:, batch_size:]
    XY = kernels[:batch_size, batch_size:]
    YX = kernels[batch_size:, :batch_size]
    loss = torch.mean(XX + YY - XY -YX)
    return loss#因为一般都是n==m,所以L矩阵一般不加入计算

def data_list(dirPath):
    # read img
    generatedImgs = []
    realImgs = []
    for root, dirs, files in os.walk(dirPath):
        for filename in sorted(files):
            # 判断该文件是否是目标文件
            if "generated" in filename:
                generatedPath = root + '/' + filename
                generatedImgs.append(readImg(generatedPath))
                # 对比图片路径
                realPath = root + '/' + filename.replace('generated', 'real')
                realImgs.append(readImg(realPath))
    return generatedImgs, realImgs


def readImg(imgPath):
    img = Image.open(imgPath)  # RGB
    # img.show()
    # PIL转numpy类型
    img = np.array(img).astype(np.float)
    return img/255

if __name__ == '__main__':
    ### 参数设定
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, default=r'D:\Project\pix2pix-master\results', help='results')
    parser.add_argument('--name', type=str, default='faces', help='name of dataset')
    opt = parser.parse_args()

    # 数据集
    dirPath = os.path.join(opt.dataset_dir, opt.name)
    generatedImgs, realImgs = data_list(dirPath)
    size = len(generatedImgs)
    print("数据集:", size)

    X = torch.Tensor(generatedImgs)
    Y = torch.Tensor(realImgs)
    print('shape: ', X.shape, Y.shape)

    # prepare the inception v3 model
    model = InceptionV3(include_top=False, pooling='avg')
    X, Y = model.predict(X),  model.predict(Y)

    X, Y = Variable(torch.Tensor(X)), Variable(torch.Tensor(Y))

    mmd = mmd_rbf(X, Y)
    print("mmd: ", mmd)

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
PyTorch对于求maximum mean discrepancy最大均值差异)可以通过以下步骤实现: 首先,通过定义核函数来测量两个分布之间的差异。可以使用高斯核函数来计算样本的欧氏距离,并通过指定带宽参数来调整核函数的宽度。PyTorch提供了torch.exp函数来计算指数函数。 其次,计算两个分布的均值。对于每个分布,可以通过计算样本张量的平均值来得到。 然后,计算最大均值差异最大均值差异是指两个分布之间的最大差异,可以通过选择最大均值差异的值来判断两个分布是否相同。计算最大均值差异可以通过计算样本集之间的核矩阵并选择其中的最大值来实现。PyTorch提供了torch.mm函数来计算矩阵乘法,并使用torch.max函数选择最大值。 最后,将上述步骤结合起来,使用PyTorch的张量操作和数学函数来实现maximum mean discrepancy的计算。具体代码如下: ```python import torch def maximum_mean_discrepancy(x, y, bandwidth): # 计算高斯核 def gaussian_kernel(x, y, bandwidth): diff = torch.unsqueeze(x, 1) - torch.unsqueeze(y, 0) norm = torch.norm(diff, dim=2) return torch.exp(-0.5 * (norm / bandwidth) ** 2) # 计算样本均值 def mean(x): return torch.mean(x, dim=0) # 计算最大均值差异 kernel_xx = gaussian_kernel(x, x, bandwidth) kernel_xy = gaussian_kernel(x, y, bandwidth) kernel_yy = gaussian_kernel(y, y, bandwidth) mmd = torch.max(torch.mean(kernel_xx) + torch.mean(kernel_yy) - 2 * torch.mean(kernel_xy)) return mmd # 示例数据 x = torch.tensor([[1, 2], [3, 4], [5, 6]]) y = torch.tensor([[2, 3], [4, 5], [6, 7]]) bandwidth = 1 # 求解最大均值差异 mmd = maximum_mean_discrepancy(x, y, bandwidth) print(mmd) ``` 这段代码演示了如何使用高斯核函数和PyTorch的张量操作来计算两个样本集之间的最大均值差异

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

代码小白的成长

计算机网络PPT下载

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值