import torch
#采用的高斯核函数
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
'''
source: 源域
target:
kernel_mul:核的倍数
kernel_num:多少个核心
n_samples = int(source.size()[0])+int(target.size()[0])
total = torch.cat([source, target], dim=0)
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
L2_distance = ((total0-total1)**2).sum(2)
#求出高斯核函数的分母||u-v||**2
if fix_sigma:
bandwidth = fix_sigma
else:
bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
#上面初始化bandwidth
bandwidth /= kernel_mul ** (kernel_num // 2)#// 表示两数相除取整 **表示幂运算
bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
'''
一系列操作求出分母的列表bandwidth_list,一共设置有五个核,所以求出的列表i从0、1、2、3、4共5个值的列表
'''
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
# kernel_val也求出5个值作为一个列表,求和值
return sum(kernel_val))
'''
当使用MMD方法的时候我们使用它,输入源域和目标域数据,这里是每次迭代都使用MMD方法。
print(batch_size)可以帮助你每次查看batch的大小,防止由于drop_last的原因导致的batch源域与目标域数据不匹配的现象出现。
'''
def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
batch_size = int(source.size()[0])
#print(batch_size)
kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
#kernels返回sum(kernel_val)
'''
实现高斯核函数kernels = sum(kernel_val)
'''
XX = kernels[:batch_size, :batch_size]#都是取源域的数据拼接前半部分
YY = kernels[batch_size:, batch_size:]#都是取目标域数据拼接后半部分
XY = kernels[:batch_size, batch_size:]
YX = kernels[batch_size:, :batch_size]
loss = torch.mean(XX + YY - XY - YX)
return loss
'''
为什么使用高斯核函数,某种意义上它可以实现无限维度的映射。
'''
MMD最大均值差异代码解析
最新推荐文章于 2023-02-11 20:41:40 发布