原始mnn测试
#主要的判断方式如下
import numpy as np
from sklearn.neighbors import NearestNeighbors
#import pyreadr
import numpy as np
# from sklearn.neighbors import NearestNeighbors
# x=pyreadr.read_r("x.RData")
# y=pyreadr.read_r("y.RData")
# x=x["x"].values
# y=y["y"].values
# Reproducible toy data: fix the global NumPy seed, then draw 20 two-dimensional
# samples for x and 10 for y from the standard normal distribution.
np.random.seed(1)
x = np.random.standard_normal(size=(20, 2))
y = np.random.standard_normal(size=(10, 2))
def findMNN(x, y, k=10):
    """Return the mutual-nearest-neighbour (MNN) pairs between ``x`` and ``y``.

    ``x[i]`` and ``y[j]`` form an MNN pair when ``y[j]`` is one of the ``k``
    nearest samples of ``x[i]`` in ``y`` AND ``x[i]`` is one of the ``k``
    nearest samples of ``y[j]`` in ``x``.  Neighbours are searched with
    scikit-learn's ``NearestNeighbors`` using its default metric (Euclidean).

    Parameters
    ----------
    x : array-like, shape (n_x, d)
        First sample set; must be non-empty.
    y : array-like, shape (n_y, d)
        Second sample set; must be non-empty with the same number of columns.
    k : int, default 10
        Neighbourhood size.  It is capped at the size of the set being searched
        (``len(y)`` for neighbours of x, ``len(x)`` for neighbours of y), so a
        ``k`` larger than either set no longer makes scikit-learn raise.

    Returns
    -------
    ndarray, shape (n_pairs, 2 * d)
        One row per pair: the ``x`` sample followed by its ``y`` partner (e.g.
        two 5-dimensional samples give a 10-column row).  Rows are ordered by
        ``x`` index, then by the partner's rank among ``x[i]``'s neighbours.
        When no pair exists an empty ``(0, 2 * d)`` array is returned instead
        of raising IndexError.

    Side effect: prints the ``(n_pairs, 2)`` integer array of
    ``(x index, y index)`` pairs.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # nn_in_y[i] holds the indices of x[i]'s nearest samples in y, nearest first.
    neigh_y = NearestNeighbors(n_neighbors=min(k, len(y))).fit(y)
    nn_in_y = neigh_y.kneighbors(x, return_distance=False)
    # nn_in_x[j] holds the indices of y[j]'s nearest samples in x.  Queried once
    # for all of y instead of re-querying y's neighbours for every x sample.
    neigh_x = NearestNeighbors(n_neighbors=min(k, len(x))).fit(x)
    nn_in_x = neigh_x.kneighbors(y, return_distance=False)
    # Keep (i, j) only when the neighbour relation holds in both directions.
    pairs = [
        (i, j)
        for i, candidates in enumerate(nn_in_y)
        for j in candidates
        if i in nn_in_x[j]
    ]
    # reshape keeps a (0, 2) shape when there is no pair, so the column
    # indexing below still works.
    mnn_indice = np.array(pairs, dtype=int).reshape(-1, 2)
    print(mnn_indice)  # debug output kept from the original notebook
    return np.concatenate((x[mnn_indice[:, 0]], y[mnn_indice[:, 1]]), axis=1)
def quary_xy_mnn(anchor_sample, positive_sample, set_x, set_y, k=20):
    """Tell whether ``(anchor_sample, positive_sample)`` is an MNN pair of the two sets.

    All MNN pairs are first computed with ``findMNN(set_x, set_y, k=k)``.  The
    query row is then built by concatenating the anchor (a sample of ``set_x``)
    with the positive (a sample of ``set_y``).  Finally, that row is looked up
    among the returned pairs by exact equality.  ``findMNN`` may return either
    the pair matrix itself or a tuple whose first element is the pair matrix;
    both forms are accepted.

    Parameters
    ----------
    anchor_sample, positive_sample : array-like
        One sample each; flattened, so a ``(1, d)`` row works like a ``(d,)`` vector.
    set_x, set_y : array-like, shape (n, d)
        The two sample sets.
    k : int, default 20
        Neighbourhood size, forwarded unchanged to ``findMNN``.

    Returns
    -------
    bool
        True if the concatenated sample pair appears among the MNN pairs.
    """
    # TODO: every call recomputes all MNN pairs; for many queries, compute them
    # once and test the candidates in a vectorized way.
    res = findMNN(set_x, set_y, k=k)
    if isinstance(res, tuple):
        res = res[0]
    query = np.concatenate((np.ravel(anchor_sample), np.ravel(positive_sample)))
    return bool((np.asarray(res) == query).all(axis=1).any())


# Usage example:
# i = 0
# j = 0
# print(quary_xy_mnn(x[i], y[j], x, y))
# MNN pairs with k=3; each row is an x sample followed by its y partner.
# Euclidean distance is used; whether another metric would also work is
# untested, but this setting currently works as expected.
mnnset = findMNN(x, y, k=3)

# Overview plot of both data sets: x in red, y in green.
import matplotlib.pyplot as plt

plt.figure(figsize=(18, 12))
plt.scatter(x[:, 0], x[:, 1], color="r", s=100)
plt.scatter(y[:, 0], y[:, 1], color="g", s=100)
# Label every point with its row index; x and y indices both start at 0.
for points in (x, y):
    for i, row in enumerate(points):
        plt.text(row[0], row[1], str(i), fontsize=20)
def connectpoints(x, y, p1, p2):
    """Highlight two points and join them with a line segment.

    ``x`` holds the horizontal and ``y`` the vertical coordinates; ``p1`` and
    ``p2`` index into both.  Point ``p1`` is drawn in red and point ``p2`` in
    green (marker size 150), then a line is plotted between them.
    """
    start = (x[p1], y[p1])
    end = (x[p2], y[p2])
    plt.scatter(*start, color='r', s=150)
    plt.scatter(*end, color="g", s=150)
    plt.plot([start[0], end[0]], [start[1], end[1]])
# Join every MNN pair with a line.  Each row of mnnset holds the x sample's two
# coordinates followed by its y partner's two coordinates.  The loop uses its own
# names (xs, ys) so the data arrays `x` and `y` are not overwritten.
for pair in mnnset:
    xs = [pair[0], pair[2]]
    ys = [pair[1], pair[3]]
    connectpoints(xs, ys, 0, 1)
# plt.axis('equal')
plt.show()
[[ 1 0]
[ 4 5]
[ 6 2]
[ 6 5]
[ 8 0]
[ 8 2]
[ 9 4]
[ 9 7]
[ 9 9]
[10 1]
[11 8]
[11 9]
[11 4]
[12 5]
[13 0]
[13 2]
[14 7]
[14 4]
[14 9]
[15 6]
[17 6]
[17 7]
[18 6]
[19 8]]
![Scatter plot of x (red) and y (green) with MNN pairs joined by lines](output_1_1.png)
mnn 返回距离
#主要的判断方式如下
import ipdb
import numpy as np
from sklearn.neighbors import NearestNeighbors
#import pyreadr
import numpy as np
# from sklearn.neighbors import NearestNeighbors
# x=pyreadr.read_r("x.RData")
# y=pyreadr.read_r("y.RData")
# x=x["x"].values
# y=y["y"].values
# Reproducible toy data: fix the global NumPy seed and draw 20 two-dimensional
# samples (xx) and 10 two-dimensional samples (yy) from the standard normal.
np.random.seed(1)
xx = np.random.standard_normal(size=(20, 2))
yy = np.random.standard_normal(size=(10, 2))
# Work on copies so xx and yy stay untouched if x or y is later rebound or modified.
x, y = xx.copy(), yy.copy()
def findMNN(x, y, k=10):
    """Return the mutual-nearest-neighbour (MNN) pairs between ``x`` and ``y`` plus their distances.

    ``x[i]`` and ``y[j]`` form an MNN pair when ``y[j]`` is one of the ``k``
    nearest samples of ``x[i]`` in ``y`` AND ``x[i]`` is one of the ``k``
    nearest samples of ``y[j]`` in ``x``.  Neighbours are searched with
    scikit-learn's ``NearestNeighbors`` using its default metric (Euclidean).

    Parameters
    ----------
    x : array-like, shape (n_x, d)
        First sample set; must be non-empty.
    y : array-like, shape (n_y, d)
        Second sample set; must be non-empty with the same number of columns.
    k : int, default 10
        Neighbourhood size.  It is capped at the size of the set being searched
        (``len(y)`` for neighbours of x, ``len(x)`` for neighbours of y), so a
        ``k`` larger than either set no longer makes scikit-learn raise.

    Returns
    -------
    tuple (matched, dist, dist2)
        matched : ndarray, shape (n_pairs, 2 * d)
            One row per pair: the ``x`` sample followed by its ``y`` partner.
            Rows are ordered by ``x`` index, then by the partner's rank among
            ``x[i]``'s neighbours.  Empty ``(0, 2 * d)`` array when no pair
            exists (instead of raising IndexError).
        dist : list of float
            ``np.linalg.norm(x[i] - y[j])`` for each pair, in row order.
        dist2 : list of float
            The same pair's distance as reported by the neighbour search of
            ``y[j]`` in ``x``; equal to ``dist`` under the Euclidean metric.

    Side effect: prints the ``(n_pairs, 2)`` integer array of
    ``(x index, y index)`` pairs.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # nn_in_y[i] holds the indices of x[i]'s nearest samples in y, nearest first;
    # its shape is (x.shape[0], k_y), not (x.shape[0], 2).
    neigh_y = NearestNeighbors(n_neighbors=min(k, len(y))).fit(y)
    nn_in_y = neigh_y.kneighbors(x, return_distance=False)
    # Neighbours (and distances) of every y sample in x, queried once for all of
    # y instead of re-querying y's neighbours for every x sample.
    neigh_x = NearestNeighbors(n_neighbors=min(k, len(x))).fit(x)
    dist_in_x, nn_in_x = neigh_x.kneighbors(y, return_distance=True)
    pairs, dist, dist2 = [], [], []
    for i, candidates in enumerate(nn_in_y):
        for j in candidates:
            # Position of x[i] in y[j]'s neighbour list, if it is there at all.
            hit = np.flatnonzero(nn_in_x[j] == i)
            if hit.size == 0:
                continue  # y[j] is near x[i], but x[i] is not near y[j]
            pairs.append((i, j))
            # Euclidean distance computed directly from the two samples ...
            dist.append(np.linalg.norm(x[i] - y[j]))
            # ... and the same distance as reported by the neighbour search.
            dist2.append(dist_in_x[j, hit[0]])
            # NOTE: a cosine-distance variant was planned but is not implemented.
    # reshape keeps a (0, 2) shape when there is no pair, so the column
    # indexing below still works.
    mnn_indice = np.array(pairs, dtype=int).reshape(-1, 2)
    print(mnn_indice)  # debug output kept from the original notebook
    matched = np.concatenate((x[mnn_indice[:, 0]], y[mnn_indice[:, 1]]), axis=1)
    return matched, dist, dist2
def quary_xy_mnn(anchor_sample, positive_sample, set_x, set_y, k=20):
    """Tell whether ``(anchor_sample, positive_sample)`` is an MNN pair of the two sets.

    All MNN pairs are first computed with ``findMNN(set_x, set_y, k=k)``.  The
    query row is then built by concatenating the anchor (a sample of ``set_x``)
    with the positive (a sample of ``set_y``).  Finally, that row is looked up
    among the returned pairs by exact equality.  ``findMNN`` may return the pair
    matrix alone or a tuple ``(pairs, dist, dist2)``; only the pair matrix is
    used.  Previously the whole tuple was compared against the query, which
    failed.

    Parameters
    ----------
    anchor_sample, positive_sample : array-like
        One sample each; flattened, so a ``(1, d)`` row works like a ``(d,)`` vector.
    set_x, set_y : array-like, shape (n, d)
        The two sample sets.
    k : int, default 20
        Neighbourhood size, forwarded unchanged to ``findMNN``.

    Returns
    -------
    bool
        True if the concatenated sample pair appears among the MNN pairs.
    """
    # TODO: every call recomputes all MNN pairs; for many queries, compute them
    # once and test the candidates in a vectorized way.
    res = findMNN(set_x, set_y, k=k)
    if isinstance(res, tuple):
        res = res[0]
    query = np.concatenate((np.ravel(anchor_sample), np.ravel(positive_sample)))
    return bool((np.asarray(res) == query).all(axis=1).any())


# Usage example:
# i = 0
# j = 0
# print(quary_xy_mnn(x[i], y[j], x, y))
# MNN pairs with k=3 plus two independently computed lists of their distances.
# Euclidean distance is used; whether another metric would also work is
# untested, but this setting currently works as expected.
mnnset, dist, dist2 = findMNN(x, y, k=3)
# Comparing dist with dist2 shows the distances are correct.

# Overview plot of both data sets: x in red, y in green.
import matplotlib.pyplot as plt

plt.figure(figsize=(18, 12))
plt.scatter(x[:, 0], x[:, 1], color="r", s=100)
plt.scatter(y[:, 0], y[:, 1], color="g", s=100)
# Label every point with its row index; x and y indices both start at 0.
for points in (x, y):
    for i, row in enumerate(points):
        plt.text(row[0], row[1], str(i), fontsize=20)
def connectpoints(x, y, p1, p2):
    """Highlight two points and join them with a line segment.

    ``x`` holds the horizontal and ``y`` the vertical coordinates; ``p1`` and
    ``p2`` index into both.  Point ``p1`` is drawn in red and point ``p2`` in
    green (marker size 150), then a line is plotted between them.
    """
    start = (x[p1], y[p1])
    end = (x[p2], y[p2])
    plt.scatter(*start, color='r', s=150)
    plt.scatter(*end, color="g", s=150)
    plt.plot([start[0], end[0]], [start[1], end[1]])
# Join every MNN pair with a line.  Each row of mnnset holds the x sample's two
# coordinates followed by its y partner's two coordinates.  The loop uses its own
# names (xs, ys) so the data arrays `x` and `y` are not overwritten.
for pair in mnnset:
    xs = [pair[0], pair[2]]
    ys = [pair[1], pair[3]]
    connectpoints(xs, ys, 0, 1)
# plt.axis('equal')
plt.show()
[[ 1 0]
[ 4 5]
[ 6 2]
[ 6 5]
[ 8 0]
[ 8 2]
[ 9 4]
[ 9 7]
[ 9 9]
[10 1]
[11 8]
[11 9]
[11 4]
[12 5]
[13 0]
[13 2]
[14 7]
[14 4]
[14 9]
[15 6]
[17 6]
[17 7]
[18 6]
[19 8]]
![Scatter plot of x (red) and y (green) with MNN pairs joined by lines](output_3_1.png)
# Display the pair matrix from findMNN: each row is an x sample followed by its y partner.
mnnset
array([[-0.52817175, -1.07296862, -0.19183555, -0.88762896],
[ 0.3190391 , -0.24937038, 0.30017032, -0.35224985],
[-0.3224172 , -0.38405435, 0.05080775, -0.63699565],
[-0.3224172 , -0.38405435, 0.30017032, -0.35224985],
[-0.17242821, -0.87785842, -0.19183555, -0.88762896],
[-0.17242821, -0.87785842, 0.05080775, -0.63699565],
[ 0.04221375, 0.58281521, 0.12015895, 0.61720311],
[ 0.04221375, 0.58281521, -0.20889423, 0.58662319],
[ 0.04221375, 0.58281521, 0.28558733, 0.88514116],
[-1.10061918, 1.14472371, -0.74715829, 1.6924546 ],
[ 0.90159072, 0.50249434, 0.83898341, 0.93110208],
[ 0.90159072, 0.50249434, 0.28558733, 0.88514116],
[ 0.90159072, 0.50249434, 0.12015895, 0.61720311],
[ 0.90085595, -0.68372786, 0.30017032, -0.35224985],
[-0.12289023, -0.93576943, -0.19183555, -0.88762896],
[-0.12289023, -0.93576943, 0.05080775, -0.63699565],
[-0.26788808, 0.53035547, -0.20889423, 0.58662319],
[-0.26788808, 0.53035547, 0.12015895, 0.61720311],
[-0.26788808, 0.53035547, 0.28558733, 0.88514116],
[-0.69166075, -0.39675353, -1.1425182 , -0.34934272],
[-0.67124613, -0.0126646 , -1.1425182 , -0.34934272],
[-0.67124613, -0.0126646 , -0.20889423, 0.58662319],
[-1.11731035, 0.2344157 , -1.1425182 , -0.34934272],
[ 1.65980218, 0.74204416, 0.83898341, 0.93110208]])
# Sanity check: recompute by hand the Euclidean distance of MNN pair 10
# (x sample in columns 0-1, y partner in columns 2-3) to compare with dist / dist2.
ind = 10
pair = mnnset[ind]
np.sqrt((pair[0] - pair[2]) ** 2 + (pair[1] - pair[3]) ** 2)
0.4331561747236103
# Number of MNN pairs: findMNN appends one entry to dist per pair.
len(dist)
24
# dist2 also has one entry per MNN pair, so this should match len(dist).
len(dist2)
24
# Distance of each MNN pair computed directly with np.linalg.norm inside findMNN.
dist
[0.38402191111759504,
0.10459548875603679,
0.4508615829852512,
0.6233993517341146,
0.021728060312701504,
0.32840397302590213,
0.08519379374484172,
0.2511368518506405,
0.38811297209658374,
0.6518770783664221,
0.4331561747236103,
0.7251749968341934,
0.7898061219774536,
0.686076452116433,
0.08408901808250398,
0.34559624501563846,
0.0815250311844277,
0.3976468435897148,
0.6574252162649653,
0.45334338117947837,
0.5791800412208201,
0.7569115750541859,
0.5843024292779387,
0.8423101207207222]
# Distance of each MNN pair as returned by the NearestNeighbors query; expected to equal dist.
dist2
[0.38402191111759504,
0.10459548875603679,
0.4508615829852512,
0.6233993517341146,
0.021728060312701504,
0.32840397302590213,
0.08519379374484172,
0.2511368518506405,
0.38811297209658374,
0.6518770783664221,
0.4331561747236103,
0.7251749968341934,
0.7898061219774536,
0.686076452116433,
0.08408901808250398,
0.34559624501563846,
0.0815250311844277,
0.3976468435897148,
0.6574252162649653,
0.45334338117947837,
0.5791800412208201,
0.7569115750541859,
0.5843024292779387,
0.8423101207207222]