读代码:geo_prior(2)

grid_predictor.py

GridPredictor类 

class GridPredictor:

    def __init__(self, mask, params, mask_only_pred=False):
        # set up coordinates to make dense prediction on grid 设置坐标以在网格上进行密集预测
        self.device = params['device']
        self.params = params
        self.mask = mask
        self.use_date_feats = params['use_date_feats']

        self.mask_lines = (np.gradient(mask)[0]**2 + np.gradient(mask)[1]**2)
        self.mask_lines[self.mask_lines > 0.0] = 1.0

        # set up feature grid this will be height X width X num feats 设置特征网格 高度 x 宽度 x 数目
        grid_lon = torch.linspace(-1, 1, mask.shape[1]).to(self.device)
        grid_lon = grid_lon.repeat(mask.shape[0],1).unsqueeze(2) 
        grid_lat = torch.linspace(1, -1, mask.shape[0]).to(self.device)
        grid_lat = grid_lat.repeat(mask.shape[1], 1).transpose(0,1).unsqueeze(2)
        dates  = torch.zeros(mask.shape[0], mask.shape[1], 1, device=self.device)

 这里涉及到torch中linspace()repeat()unsqueeze()几个方法的用法。

        if self.use_date_feats:  #使用拍摄时间特征
            loc_time_feats = torch.cat((grid_lon, grid_lat, dates), 2)
            loc_time_feats = ut.encode_loc_time(loc_time_feats[:,:,:2], loc_time_feats[:,:,2], concat_dim=2, params=params)
        else:  #不使用时间特征
            loc_time_feats = torch.cat((grid_lon, grid_lat), 2)
            loc_time_feats = ut.encode_loc_time(loc_time_feats[:,:,:2], None, concat_dim=2, params=params)
        self.feats = loc_time_feats

        # for mask only prediction
        if mask_only_pred:
            self.mask_inds = np.where(self.mask.ravel() == 1)[0]
            self.feats_local = self.feats.reshape(self.feats.shape[0]*self.feats.shape[1], self.feats.shape[2])[self.mask_inds, :].clone()

使用或不使用日期/时间特征:这里使用了拼接函数 torch.cat() 

utils.py

encode_loc_time 

def encode_loc_time(loc_ip, date_ip, concat_dim=1, params=None):
    # assumes inputs location and date features are in range -1 to 1
    #假定输入的位置和日期特征都在范围-1到1之间
    # location is lon, lat
    #位置是经度纬度

    if params['loc_encode'] == 'encode_cos_sin':
        feats = torch.cat((torch.sin(math.pi*loc_ip), torch.cos(math.pi*loc_ip)), concat_dim)

    elif params['loc_encode'] == 'encode_3D':
        # X, Y, Z in 3D space
        if concat_dim == 1:
            cos_lon = torch.cos(math.pi*loc_ip[:, 0]).unsqueeze(-1)
            sin_lon = torch.sin(math.pi*loc_ip[:, 0]).unsqueeze(-1)
            cos_lat = torch.cos(math.pi*loc_ip[:, 1]).unsqueeze(-1)
            sin_lat = torch.sin(math.pi*loc_ip[:, 1]).unsqueeze(-1)
        if concat_dim == 2:
            cos_lon = torch.cos(math.pi*loc_ip[:, :, 0]).unsqueeze(-1)
            sin_lon = torch.sin(math.pi*loc_ip[:, :, 0]).unsqueeze(-1)
            cos_lat = torch.cos(math.pi*loc_ip[:, :, 1]).unsqueeze(-1)
            sin_lat = torch.sin(math.pi*loc_ip[:, :, 1]).unsqueeze(-1)
        feats = torch.cat((cos_lon*cos_lat, sin_lon*cos_lat, sin_lat), concat_dim)

    elif params['loc_encode'] == 'encode_none':
        feats = loc_ip

    else:
        print('error - no loc feat type defined')


    if params['use_date_feats']:
        if params['date_encode'] == 'encode_cos_sin':
            feats_date = torch.cat((torch.sin(math.pi*date_ip.unsqueeze(-1)),
                                    torch.cos(math.pi*date_ip.unsqueeze(-1))), concat_dim)
        elif params['date_encode'] == 'encode_none':
            feats_date = date_ip.unsqueeze(-1)
        else:
            print('error - no date feat type defined')
        feats = torch.cat((feats, feats_date), concat_dim)

    return feats

将输入的位置和时间信息进行编码的函数。传入的参数有位置、维度、类型参数。 此处对应文章的实现细节一节中,为了地理坐标能够遍布全球,对于x的每个输入维度 l,我们采用一个映射[sin(\pi x^{l}),cos(\pi x^{l})],将每个维度映射到两个数字(规范化到范围x^{l} \in [-1,1])。

 BalanceSampler类

class BalancedSampler(Sampler):
    # sample "evenly" from each from class
    def __init__(self, classes, num_per_class, use_replace=False, multi_label=False):
        self.class_dict = {}
        self.num_per_class = num_per_class
        self.use_replace = use_replace
        self.multi_label = multi_label

        if self.multi_label:
            self.class_dict = classes
        else:
            # standard classification
            un_classes = np.unique(classes)
            for cc in un_classes:
                self.class_dict[cc] = []

            for ii in range(len(classes)):
                self.class_dict[classes[ii]].append(ii)

        if self.use_replace:
            self.num_exs = self.num_per_class*len(un_classes)
        else:
            self.num_exs = 0
            for cc in self.class_dict.keys():
                self.num_exs += np.minimum(len(self.class_dict[cc]), self.num_per_class)


    def __iter__(self):
        indices = []
        for cc in self.class_dict:
            if self.use_replace:
                indices.extend(np.random.choice(self.class_dict[cc], self.num_per_class).tolist())
            else:
                indices.extend(np.random.choice(self.class_dict[cc], np.minimum(len(self.class_dict[cc]),
                                                self.num_per_class), replace=False).tolist())
        # in the multi label setting there will be duplictes at training time
        np.random.shuffle(indices)  # will remain a list
        return iter(indices)

    def __len__(self):
        return self.num_exs

convert_loc_to_tensor 

def convert_loc_to_tensor(x, device=None):
    # intput is in lon {-180, 180}, lat {90, -90}
    xt = x.astype(np.float32)
    xt[:,0] /= 180.0
    xt[:,1] /= 90.0
    xt = torch.from_numpy(xt)
    if device is not None:
        xt = xt.to(device)
    return xt

将位置信息转换到tensor。输入是经纬度,将经纬度都规范化到[-1,1]的范围。

 一些距离函数

def distance_pw_euclidean(xx, yy):
    # equivalent to scipy.spatial.distance.cdist
    dist = np.sqrt((xx**2).sum(1)[:, np.newaxis] - 2*xx.dot(yy.transpose()) + ((yy**2).sum(1)[np.newaxis, :]))
    return dist


def distance_pw_haversine(xx, yy, radius=6372.8):
    # input should be in radians
    # output is in km's if radius = 6372.8

    d_lon = xx[:, 0][..., np.newaxis] - yy[:, 0][np.newaxis, ...]
    d_lat = xx[:, 1][..., np.newaxis] - yy[:, 1][np.newaxis, ...]

    cos_term = np.cos(xx[:,1])[..., np.newaxis]*np.cos(yy[:, 1])[np.newaxis, ...]
    dist = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2
    dist = 2 * radius * np.arcsin(np.sqrt(dist))
    return dist


def euclidean_distance(xx, yy):
    return np.sqrt(((xx - yy)**2).sum(1))


def haversine_distance(xx, yy, radius=6371.4):
    # assumes shape N x 2, where col 0 is lat, and col 1 is lon
    # input should be in radians
    # output is in km's if radius = 6371.4
    # note that SKLearns haversine distance is [latitude, longitude] not [longitude, latitude]

    d_lon = xx[:, 0] - yy[0]    #经度
    d_lat = xx[:, 1] - yy[1]    #纬度

    cos_term = np.cos(xx[:,1])*np.cos(yy[1]) #求两点纬度的cos值的乘积
    dist = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2  #上面的乘积乘经度*0.5的sin的平方
    dist = 2 * radius * np.arcsin(np.sqrt(dist + 1e-16))

    return dist

引用地理空间距离计算优化

其中

  • R为地球半径,可取平均值 6371km;
  • φ1, φ2 表示两点的纬度;
  • Δλ 表示两点经度的差值。

根据2个经纬度点,计算这2个经纬度点之间的距离(通过经度纬度得到距离)

bilinear_interpolate 

双线性插值 

  • 输入参数loc_ip是一个N x 2的向量
  • 每一行都是[经度,纬度]的格式,范围均在[-1,1]内
  • 数据是H x W x C维度,即高度 x 宽度 x 通道的数据矩阵
  • 输出是N x C格式的被插值特征矩阵
  • 将被映射到[0,1]
def bilinear_interpolate(loc_ip, data, remove_nans=False):
    # loc is N x 2 vector, where each row is [lon,lat] entry
    #   each entry spans range [-1,1]
    # data is H x W x C, height x width x channel data matrix
    # op will be N x C matrix of interpolated features

    # map to [0,1], then scale to data size
    loc = (loc_ip.clone() + 1) / 2.0
    loc[:,1] = 1 - loc[:,1]   # this is because latitude goes from +90 on top to bottom while
#这是因为纬度从上到下是从+90开始的,而经度从左到右是从-90到90
                              # longitude goes from -90 to 90 left to right
    if remove_nans:
        loc[torch.isnan(loc)] = 0.5
    loc[:, 0] *= (data.shape[1]-1)
    loc[:, 1] *= (data.shape[0]-1)

    loc_int = torch.floor(loc).long()  # integer pixel coordinates
    xx = loc_int[:, 0]
    yy = loc_int[:, 1]
    xx_plus = xx + 1
    xx_plus[xx_plus > (data.shape[1]-1)] = data.shape[1]-1
    yy_plus = yy + 1
    yy_plus[yy_plus > (data.shape[0]-1)] = data.shape[0]-1

    loc_delta = loc - torch.floor(loc)   # delta values
    dx = loc_delta[:, 0].unsqueeze(1)
    dy = loc_delta[:, 1].unsqueeze(1)
    interp_val = data[yy, xx, :]*(1-dx)*(1-dy) + data[yy, xx_plus, :]*dx*(1-dy) + \
                 data[yy_plus, xx, :]*(1-dx)*dy   + data[yy_plus, xx_plus, :]*dx*dy

    return interp_val

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值