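Computing distances between the training set and the test set

Compute the distance from each test point to every training point (afterwards, for each test point, pick the k training points closest to it).
Use the second method: the vectorized distance computation is far more efficient than the explicit loops.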
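The vectorized version works because the squared Euclidean distance expands into three terms that can each be computed for all pairs at once. As a short derivation, writing x_i for the i-th test row and y_j for the j-th training row (these symbols are introduced here for illustration; te and tr match the variable names in the code that follows):

```latex
\begin{aligned}
\lVert x_i - y_j \rVert^2
  &= \sum_{d=1}^{D} (x_{id} - y_{jd})^2 \\
  &= \sum_{d=1}^{D} x_{id}^2 \;-\; 2 \sum_{d=1}^{D} x_{id}\, y_{jd} \;+\; \sum_{d=1}^{D} y_{jd}^2 \\
  &= \underbrace{\lVert x_i \rVert^2}_{te_i}
     \;-\; 2\, \underbrace{x_i^{\top} y_j}_{M_{ij}}
     \;+\; \underbrace{\lVert y_j \rVert^2}_{tr_j}
\end{aligned}
```

Here M is the matrix product of the test matrix with the transposed training matrix, so the whole (num_test, num_train) table of squared distances comes from one matrix multiplication plus two vectors of row norms.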

  def compute_distances_two_loops(self, X):
    """
    Compute the distance between each test point in X and each training point
    in self.X_train. self.X_train is the training data and X is the test data.

    Inputs:
    - X: A numpy array of shape (num_test, D), here (500, 3072), containing
      test data. self.X_train has shape (num_train, D), here (5000, 3072).

    Returns:
    - dists: A numpy array of shape (num_test, num_train), here (500, 5000),
      where dists[i, j] is the Euclidean distance between the ith test point
      and the jth training point.
    """
    num_test = X.shape[0]
    num_train = self.X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    for i in range(num_test):
      for j in range(num_train):
        dist = np.sqrt(np.sum(np.square(X[i] - self.X_train[j])))
        dists[i, j] = dist
    return dists
    
  def compute_distances_no_loops(self, X):
    """
    Compute the distance between each test point in X and each training point
    in self.X_train using no explicit loops.

    Input / Output: Same as compute_distances_two_loops
    """
    # Formulate the L2 distance using matrix multiplication and broadcasting:
    # ||x - y||^2 = ||x||^2 - 2 * x.y + ||y||^2
    M = np.dot(X, self.X_train.T)  # cross terms, shape (num_test, num_train)
    # Squared norm of every row of X, kept as a column: shape (num_test, 1).
    # Summing squares row-wise avoids building the full X.dot(X.T) matrix
    # just to read its diagonal.
    te = np.sum(np.square(X), axis=1, keepdims=True)
    # Squared norm of every training row: shape (num_train,)
    tr = np.sum(np.square(self.X_train), axis=1)
    # Broadcasting expands te across columns and tr across rows.
    distance_square = te - 2 * M + tr
    # Round-off can leave tiny negative values; clip them before the square root.
    dists = np.sqrt(np.maximum(distance_square, 0))
    return dists
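
To check that the two approaches agree, here is a minimal standalone sketch. It is not the classifier's own code: the functions pairwise_l2_loops and pairwise_l2_vectorized are hypothetical free-function versions of the two methods above, the data is small random data rather than the (500, 3072) and (5000, 3072) arrays, and k = 3 is an arbitrary choice for illustrating the nearest-neighbor lookup.

```python
import numpy as np


def pairwise_l2_loops(X, X_train):
    """Reference version: one distance per (test, train) pair."""
    num_test, num_train = X.shape[0], X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    for i in range(num_test):
        for j in range(num_train):
            dists[i, j] = np.sqrt(np.sum(np.square(X[i] - X_train[j])))
    return dists


def pairwise_l2_vectorized(X, X_train):
    """Same distances via ||x||^2 - 2 x.y + ||y||^2 and broadcasting."""
    te = np.sum(np.square(X), axis=1, keepdims=True)  # (num_test, 1)
    tr = np.sum(np.square(X_train), axis=1)           # (num_train,)
    sq = te - 2.0 * (X @ X_train.T) + tr              # (num_test, num_train)
    # Clip small negative round-off values before taking the square root.
    return np.sqrt(np.maximum(sq, 0.0))


# Small random data (assumed sizes, chosen only to keep the loops fast).
rng = np.random.default_rng(0)
X_train = rng.standard_normal((50, 8))
X = rng.standard_normal((5, 8))

d_loops = pairwise_l2_loops(X, X_train)
d_vec = pairwise_l2_vectorized(X, X_train)
print(np.allclose(d_loops, d_vec))

# For each test point, indices of the k closest training points.
k = 3
nearest = np.argsort(d_vec, axis=1)[:, :k]
print(nearest.shape)
```

The vectorized function never materializes a (num_test, num_train, D) difference array; the only large intermediate is the (num_test, num_train) product, which is why it scales to the full dataset while the double loop does not.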