mnn_距离实现

原始mnn测试

#主要的判断方式如下
import numpy as np
from sklearn.neighbors import NearestNeighbors
#import pyreadr
import numpy as np
# from sklearn.neighbors import NearestNeighbors
# x=pyreadr.read_r("x.RData")
# y=pyreadr.read_r("y.RData")
# x=x["x"].values
# y=y["y"].values
np.random.seed(1)#设置随机种子
x=np.random.randn(20,2)# x是二维的
y=np.random.randn(10,2)# y是2维的
#x ndarray,y ndarray,我想这些mnn pair
# return ndarray ,先默认是欧式距离,
def findMNN(x,y,k=10):
    neigh_y = NearestNeighbors(n_neighbors=k).fit(y)
    indice_y=neigh_y.kneighbors(x, return_distance=False)#对数据集x,在y中找它的k最近邻,返回下标
    neigh_x = NearestNeighbors(n_neighbors=k).fit(x)
    cnt=0;
    mnnset=[]
    for ind_y in indice_y:
        temp=y[ind_y]
        indice_x=neigh_x.kneighbors(temp,return_distance=False)
        row,col=np.where(indice_x==cnt)
        for temp_y in row:
            mnnset.append([cnt,ind_y[temp_y]])
            #mnnset.add((cnt,ind_y[temp_y]))
        cnt=cnt+1
    mnn_indice=np.array(mnnset)# 我不想返回indice
    print(mnn_indice)# 此处要不要返回下表
    #eturn(x[mnnset])
    #return(x[mnn_indice[:,0]],y[mnn_indice[:,1]])#目前是返回元组,我觉得我直接返回矩阵算了
    # 就是res[0]和res[1]返回的矩阵进行拼接。如果res[0]是5维的,res[1]是5维的,那么合并后就是10维的
    res=(x[mnn_indice[:,0]],y[mnn_indice[:,1]])#这个是元组形式
    return(np.concatenate((res[0],res[1]),axis=1))# 直接返回了
#首先返回的集合mnn pair的所有集合

#给定anchor_sample,positive_sample,集合x,y,判断这一个样本对是否是mnn pair,我觉得是不是应该直接改成向量的
def quary_xy_mnn(anchor_sample,positive_sample,set_x,set_y,k=20):
    res=findMNN(set_x,set_y,k=k);
    temp_test=np.concatenate((anchor_sample,positive_sample),axis=0)#这个是一维的,所以不存在axis=1,
    return(any((res==temp_test).all(1)))#判断该元素在不在里面

#使用案例,
#i=0;
#j=0;
#print(quary_xy_mnn(x[i],y[j],x,y))
mnnset=findMNN(x,y,k=3)# 这里是欧式距离找的mnn,如果用别的距离不知道可不可以,这个目前是没有什么问题的

# 整体的数据图
import matplotlib.pyplot as plt
plt.figure(figsize=(18,12))
plt.scatter(x[:,0],x[:,1],color="r",s=100)
plt.scatter(y[:,0],y[:,1],color="g",s=100)

for i in range(x.shape[0]):
    plt.text(x[i,0], x[i,1], str(i),fontsize=20)
for i in range(y.shape[0]):
    plt.text(y[i,0], y[i,1], str(i),fontsize=20)    
    

def connectpoints(x,y,p1,p2):# 现在仅仅画两个点
    x1, x2 = x[p1], x[p2]
    y1, y2 = y[p1], y[p2]
    plt.scatter(x1,y1,color='r',s=150)
    plt.scatter(x2,y2,color="g",s=150)
    plt.plot([x1,x2],[y1,y2])

for i in range(len(mnnset)):
    x=[mnnset[i,0],mnnset[i,2]]
    y=[mnnset[i,1],mnnset[i,3]]
    connectpoints(x,y,0,1)
#plt.axis('equal')
plt.show()
[[ 1  0]
 [ 4  5]
 [ 6  2]
 [ 6  5]
 [ 8  0]
 [ 8  2]
 [ 9  4]
 [ 9  7]
 [ 9  9]
 [10  1]
 [11  8]
 [11  9]
 [11  4]
 [12  5]
 [13  0]
 [13  2]
 [14  7]
 [14  4]
 [14  9]
 [15  6]
 [17  6]
 [17  7]
 [18  6]
 [19  8]]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TOjqf9P7-1621915296836)(output_1_1.png)]

mnn 返回距离

#主要的判断方式如下
import ipdb
import numpy as np
from sklearn.neighbors import NearestNeighbors
#import pyreadr
import numpy as np
# from sklearn.neighbors import NearestNeighbors
# x=pyreadr.read_r("x.RData")
# y=pyreadr.read_r("y.RData")
# x=x["x"].values
# y=y["y"].values
np.random.seed(1)#设置随机种子
xx=np.random.randn(20,2)# x是二维的
yy=np.random.randn(10,2)# y是2维的
x=xx.copy()
y=yy.copy()
#x ndarray,y ndarray,我想这些mnn pair
# return ndarray ,先默认是欧式距离,
def findMNN(x,y,k=10):
    neigh_y = NearestNeighbors(n_neighbors=k).fit(y)
    indice_y=neigh_y.kneighbors(x, return_distance=False)#对数据集x,在y中找它的k最近邻,返回下标,indice_y是一个二维矩阵,并不是一维的
    # 这个地方注意Indice_y的维度是[x.shape[0],2]
    neigh_x = NearestNeighbors(n_neighbors=k).fit(x)
    cnt=0;
    mnnset=[]
    dist=[]
    dist2=[]
    for ind_y in indice_y:
       
        temp=y[ind_y]# 注意这个地方ind_y并不是一个元素,而取决于k,q
        dist_x,indice_x=neigh_x.kneighbors(temp,return_distance=True)# 给定k个temp点,返回数据集x中离temp最近的这些点的index
        row,col=np.where(indice_x==cnt)# indice_x是一个k*k的矩阵
        #print(indice_x.shape)
        #print(dist_x.shape)
        for temp_y,temp_col in zip(row,col):
            #ipdb.set_trace()
            mnnset.append([cnt,ind_y[temp_y]])
            # 欧式距离
            dist.append(np.linalg.norm(x[cnt]-y[ind_y[temp_y]]))
            dist2.append(dist_x[temp_y,temp_col])
            # 如果是cosine距离
            #mnnset.add((cnt,ind_y[temp_y]))
        cnt=cnt+1
    mnn_indice=np.array(mnnset)# 我不想返回indice
    print(mnn_indice)# 此处要不要返回下表
    #eturn(x[mnnset])
    #return(x[mnn_indice[:,0]],y[mnn_indice[:,1]])#目前是返回元组,我觉得我直接返回矩阵算了
    # 就是res[0]和res[1]返回的矩阵进行拼接。如果res[0]是5维的,res[1]是5维的,那么合并后就是10维的
    res=(x[mnn_indice[:,0]],y[mnn_indice[:,1]])#这个是元组形式
    return(np.concatenate((res[0],res[1]),axis=1)),dist,dist2# 直接返回了
#首先返回的集合mnn pair的所有集合

#给定anchor_sample,positive_sample,集合x,y,判断这一个样本对是否是mnn pair,我觉得是不是应该直接改成向量的
def quary_xy_mnn(anchor_sample,positive_sample,set_x,set_y,k=20):
    res=findMNN(set_x,set_y,k=k);
    temp_test=np.concatenate((anchor_sample,positive_sample),axis=0)#这个是一维的,所以不存在axis=1,
    return(any((res==temp_test).all(1)))#判断该元素在不在里面

#使用案例,
#i=0;
#j=0;
#print(quary_xy_mnn(x[i],y[j],x,y))
mnnset,dist,dist2=findMNN(x,y,k=3)# 这里是欧式距离找的mnn,如果用别的距离不知道可不可以,这个目前是没有什么问题的
# 可以看到这个距离是正确的e

# 整体的数据图
import matplotlib.pyplot as plt
plt.figure(figsize=(18,12))
plt.scatter(x[:,0],x[:,1],color="r",s=100)
plt.scatter(y[:,0],y[:,1],color="g",s=100)

for i in range(x.shape[0]):
    plt.text(x[i,0], x[i,1], str(i),fontsize=20)
for i in range(y.shape[0]):
    plt.text(y[i,0], y[i,1], str(i),fontsize=20)    
    

def connectpoints(x,y,p1,p2):# 现在仅仅画两个点
    x1, x2 = x[p1], x[p2]
    y1, y2 = y[p1], y[p2]
    plt.scatter(x1,y1,color='r',s=150)
    plt.scatter(x2,y2,color="g",s=150)
    plt.plot([x1,x2],[y1,y2])

for i in range(len(mnnset)):
    x=[mnnset[i,0],mnnset[i,2]]
    y=[mnnset[i,1],mnnset[i,3]]
    connectpoints(x,y,0,1)
#plt.axis('equal')
plt.show()
[[ 1  0]
 [ 4  5]
 [ 6  2]
 [ 6  5]
 [ 8  0]
 [ 8  2]
 [ 9  4]
 [ 9  7]
 [ 9  9]
 [10  1]
 [11  8]
 [11  9]
 [11  4]
 [12  5]
 [13  0]
 [13  2]
 [14  7]
 [14  4]
 [14  9]
 [15  6]
 [17  6]
 [17  7]
 [18  6]
 [19  8]]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BHWSLnAS-1621915296840)(output_3_1.png)]

mnnset
array([[-0.52817175, -1.07296862, -0.19183555, -0.88762896],
       [ 0.3190391 , -0.24937038,  0.30017032, -0.35224985],
       [-0.3224172 , -0.38405435,  0.05080775, -0.63699565],
       [-0.3224172 , -0.38405435,  0.30017032, -0.35224985],
       [-0.17242821, -0.87785842, -0.19183555, -0.88762896],
       [-0.17242821, -0.87785842,  0.05080775, -0.63699565],
       [ 0.04221375,  0.58281521,  0.12015895,  0.61720311],
       [ 0.04221375,  0.58281521, -0.20889423,  0.58662319],
       [ 0.04221375,  0.58281521,  0.28558733,  0.88514116],
       [-1.10061918,  1.14472371, -0.74715829,  1.6924546 ],
       [ 0.90159072,  0.50249434,  0.83898341,  0.93110208],
       [ 0.90159072,  0.50249434,  0.28558733,  0.88514116],
       [ 0.90159072,  0.50249434,  0.12015895,  0.61720311],
       [ 0.90085595, -0.68372786,  0.30017032, -0.35224985],
       [-0.12289023, -0.93576943, -0.19183555, -0.88762896],
       [-0.12289023, -0.93576943,  0.05080775, -0.63699565],
       [-0.26788808,  0.53035547, -0.20889423,  0.58662319],
       [-0.26788808,  0.53035547,  0.12015895,  0.61720311],
       [-0.26788808,  0.53035547,  0.28558733,  0.88514116],
       [-0.69166075, -0.39675353, -1.1425182 , -0.34934272],
       [-0.67124613, -0.0126646 , -1.1425182 , -0.34934272],
       [-0.67124613, -0.0126646 , -0.20889423,  0.58662319],
       [-1.11731035,  0.2344157 , -1.1425182 , -0.34934272],
       [ 1.65980218,  0.74204416,  0.83898341,  0.93110208]])
ind=10
np.sqrt((mnnset[ind][0]-mnnset[ind][2])**2+(mnnset[ind][1]-mnnset[ind][3])**2)
0.4331561747236103
len(dist)
24
len(dist2)
24
dist
[0.38402191111759504,
 0.10459548875603679,
 0.4508615829852512,
 0.6233993517341146,
 0.021728060312701504,
 0.32840397302590213,
 0.08519379374484172,
 0.2511368518506405,
 0.38811297209658374,
 0.6518770783664221,
 0.4331561747236103,
 0.7251749968341934,
 0.7898061219774536,
 0.686076452116433,
 0.08408901808250398,
 0.34559624501563846,
 0.0815250311844277,
 0.3976468435897148,
 0.6574252162649653,
 0.45334338117947837,
 0.5791800412208201,
 0.7569115750541859,
 0.5843024292779387,
 0.8423101207207222]
dist2
[0.38402191111759504,
 0.10459548875603679,
 0.4508615829852512,
 0.6233993517341146,
 0.021728060312701504,
 0.32840397302590213,
 0.08519379374484172,
 0.2511368518506405,
 0.38811297209658374,
 0.6518770783664221,
 0.4331561747236103,
 0.7251749968341934,
 0.7898061219774536,
 0.686076452116433,
 0.08408901808250398,
 0.34559624501563846,
 0.0815250311844277,
 0.3976468435897148,
 0.6574252162649653,
 0.45334338117947837,
 0.5791800412208201,
 0.7569115750541859,
 0.5843024292779387,
 0.8423101207207222]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值