grid_predictor.py
GridPredictor类
class GridPredictor:
    """Precompute encoded features for every cell of a dense prediction grid.

    The grid has the same height/width as ``mask``. Longitude runs from -1
    (left column) to 1 (right column) and latitude from 1 (top row) to -1
    (bottom row).

    Attributes set here:
        feats: features for every grid cell, produced by
            ``ut.encode_loc_time`` with ``concat_dim=2`` (height x width x
            num_feats).
        mask_lines: array that is 1.0 wherever ``mask`` has a non-zero
            gradient (i.e. where its value changes) and 0.0 elsewhere.
        mask_inds / feats_local: only when ``mask_only_pred`` is True; the
            flat indices of cells where ``mask == 1`` and a copy of the
            feature rows of those cells.
    """

    def __init__(self, mask, params, mask_only_pred=False):
        # Set up coordinates to make a dense prediction on the grid.
        self.device = params['device']
        self.params = params
        self.mask = mask
        self.use_date_feats = params['use_date_feats']

        # np.gradient returns one array per axis; compute it only once.
        mask_grad = np.gradient(mask)
        self.mask_lines = mask_grad[0]**2 + mask_grad[1]**2
        self.mask_lines[self.mask_lines > 0.0] = 1.0

        # Set up the feature grid: height x width x num_feats.
        # This uses torch's linspace() (1-D axis values), repeat() (tile the
        # axis over the other dimension) and unsqueeze() (add a channel dim).
        height, width = mask.shape[0], mask.shape[1]
        grid_lon = torch.linspace(-1, 1, width).to(self.device)
        grid_lon = grid_lon.repeat(height, 1).unsqueeze(2)
        grid_lat = torch.linspace(1, -1, height).to(self.device)
        grid_lat = grid_lat.repeat(width, 1).transpose(0, 1).unsqueeze(2)

        # Stack lon/lat along the channel dim (dim 2) and encode them.
        grid_loc = torch.cat((grid_lon, grid_lat), 2)
        if self.use_date_feats:
            # Every cell uses the same (zero) date value.
            dates = torch.zeros(height, width, device=self.device)
            self.feats = ut.encode_loc_time(grid_loc, dates, concat_dim=2, params=params)
        else:
            self.feats = ut.encode_loc_time(grid_loc, None, concat_dim=2, params=params)

        # For mask-only prediction keep just the features of cells inside the mask.
        if mask_only_pred:
            self.mask_inds = np.where(self.mask.ravel() == 1)[0]
            flat_feats = self.feats.reshape(-1, self.feats.shape[2])
            self.feats_local = flat_feats[self.mask_inds, :].clone()
根据是否使用日期/时间特征分别构造输入:这里使用拼接函数 torch.cat() 沿第 2 维(特征维)把经度、纬度(以及日期)网格拼接起来,再交给 encode_loc_time 进行编码。
utils.py
encode_loc_time
def encode_loc_time(loc_ip, date_ip, concat_dim=1, params=None):
    """Encode location (and optionally date) inputs into feature tensors.

    Assumes the location and date inputs are already normalised to [-1, 1].

    Args:
        loc_ip: tensor whose last axis holds [lon, lat], e.g. (N, 2) with
            ``concat_dim=1`` or (H, W, 2) with ``concat_dim=2``.
        date_ip: date tensor shaped like ``loc_ip`` without its last axis,
            or None when ``params['use_date_feats']`` is False.
        concat_dim: axis along which feature channels are concatenated. The
            encodings treat it as the last axis of ``loc_ip``.
        params: dict with keys 'loc_encode' and 'use_date_feats', plus
            'date_encode' when date features are used.

    Returns:
        Tensor of encoded features concatenated along ``concat_dim``.

    Raises:
        ValueError: if 'loc_encode' or 'date_encode' is not a known encoding.
    """
    loc_encode = params['loc_encode']
    if loc_encode == 'encode_cos_sin':
        # Map each input dim x to [sin(pi x), cos(pi x)]; -1 and 1 give the
        # same values, so the encoding wraps around.
        feats = torch.cat((torch.sin(math.pi*loc_ip), torch.cos(math.pi*loc_ip)), concat_dim)
    elif loc_encode == 'encode_3D':
        # X, Y, Z in 3D space. ``...`` selects the lon/lat channel of the
        # last axis for any number of leading axes (same as [:, k] for 2-D
        # input and [:, :, k] for 3-D input).
        # NOTE(review): lat is multiplied by pi rather than pi/2, so for
        # inputs normalised by 90 degrees this is not the true unit-sphere
        # position; kept unchanged for compatibility - confirm intent.
        cos_lon = torch.cos(math.pi*loc_ip[..., 0]).unsqueeze(-1)
        sin_lon = torch.sin(math.pi*loc_ip[..., 0]).unsqueeze(-1)
        cos_lat = torch.cos(math.pi*loc_ip[..., 1]).unsqueeze(-1)
        sin_lat = torch.sin(math.pi*loc_ip[..., 1]).unsqueeze(-1)
        feats = torch.cat((cos_lon*cos_lat, sin_lon*cos_lat, sin_lat), concat_dim)
    elif loc_encode == 'encode_none':
        feats = loc_ip
    else:
        raise ValueError(f'error - no loc feat type defined: {loc_encode!r}')

    if params['use_date_feats']:
        date_encode = params['date_encode']
        date_col = date_ip.unsqueeze(-1)  # add a channel axis
        if date_encode == 'encode_cos_sin':
            feats_date = torch.cat((torch.sin(math.pi*date_col),
                                    torch.cos(math.pi*date_col)), concat_dim)
        elif date_encode == 'encode_none':
            feats_date = date_col
        else:
            raise ValueError(f'error - no date feat type defined: {date_encode!r}')
        feats = torch.cat((feats, feats_date), concat_dim)
    return feats
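将输入的位置和时间信息进行编码的函数。传入的参数有位置 loc_ip、日期 date_ip、拼接维度 concat_dim,以及包含编码类型的 params。此处对应文章“实现细节”一节:为了让地理坐标在整个地球上保持连续(例如经度 -180° 与 180° 首尾相接),对于已规范化到 $[-1, 1]$ 范围的输入 $\mathbf{x}$ 的每个维度 $l$,采用映射 $x_l \mapsto [\sin(\pi x_l), \cos(\pi x_l)]$,将每个维度映射为两个数字。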
BalancedSampler类
class BalancedSampler(Sampler):
    """Sample "evenly" from each class.

    Args:
        classes: for standard classification, a sequence holding the class
            label of every example; when ``multi_label`` is True, a dict
            mapping each class to the list of example indices in it.
        num_per_class: number of indices drawn from each class per epoch.
        use_replace: draw with replacement, so every class contributes exactly
            ``num_per_class`` indices; otherwise a class contributes at most
            as many indices as it has examples.
        multi_label: interpret ``classes`` as the class -> indices dict.
    """

    def __init__(self, classes, num_per_class, use_replace=False, multi_label=False):
        self.num_per_class = num_per_class
        self.use_replace = use_replace
        self.multi_label = multi_label

        if self.multi_label:
            self.class_dict = classes
        else:
            # Standard classification: group example indices by label. Keys
            # are created in np.unique (sorted) order.
            self.class_dict = {cc: [] for cc in np.unique(classes)}
            for ii, cc in enumerate(classes):
                self.class_dict[cc].append(ii)

        # Number of indices produced per epoch. Computed from class_dict so it
        # is defined in both the standard and the multi-label mode.
        if self.use_replace:
            self.num_exs = self.num_per_class * len(self.class_dict)
        else:
            self.num_exs = sum(min(len(inds), self.num_per_class)
                               for inds in self.class_dict.values())

    def __iter__(self):
        indices = []
        for inds in self.class_dict.values():
            if self.use_replace:
                picked = np.random.choice(inds, self.num_per_class)
            else:
                picked = np.random.choice(inds, min(len(inds), self.num_per_class),
                                          replace=False)
            indices.extend(picked.tolist())
        # In the multi-label setting an index can belong to several classes,
        # so there may be duplicates at training time.
        np.random.shuffle(indices)  # shuffles the list in place
        return iter(indices)

    def __len__(self):
        return self.num_exs
convert_loc_to_tensor
def convert_loc_to_tensor(x, device=None):
    """Normalise [lon, lat] rows to [-1, 1] and return them as a float32 tensor.

    Input columns are lon in [-180, 180] and lat in [-90, 90]; they are divided
    by 180 and 90 respectively. Any further columns are copied unchanged.
    ``astype`` makes a copy, so the caller's array is not modified. If
    ``device`` is given the tensor is moved there.
    """
    coords = x.astype(np.float32)
    for col, half_range in ((0, 180.0), (1, 90.0)):
        coords[:, col] /= half_range
    tensor = torch.from_numpy(coords)
    return tensor if device is None else tensor.to(device)
将位置信息转换为 tensor。输入是经纬度(经度范围 [-180, 180],纬度范围 [-90, 90]),分别除以 180 和 90,把经纬度都规范化到 [-1, 1] 的范围。
一些距离函数
def distance_pw_euclidean(xx, yy):
    """Pairwise Euclidean distances between the rows of xx and yy.

    Equivalent to ``scipy.spatial.distance.cdist(xx, yy)``: for xx of shape
    (N, D) and yy of shape (M, D) the result has shape (N, M).

    Uses ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, which avoids building an
    (N, M, D) difference array.
    """
    sq_dist = ((xx**2).sum(1)[:, np.newaxis]
               - 2*xx.dot(yy.transpose())
               + (yy**2).sum(1)[np.newaxis, :])
    # Floating-point cancellation can leave small negative values for
    # (near-)identical points; clamp them so sqrt never returns NaN.
    return np.sqrt(np.maximum(sq_dist, 0.0))
def distance_pw_haversine(xx, yy, radius=6372.8):
    """Pairwise great-circle (haversine) distances between rows of xx and yy.

    xx has shape (N, 2) and yy (M, 2). Column 0 is longitude and column 1 is
    latitude, both in radians. Returns an (N, M) array in the units of
    ``radius`` (km for the default 6372.8).
    """
    d_lon = xx[:, 0][..., np.newaxis] - yy[:, 0][np.newaxis, ...]
    d_lat = xx[:, 1][..., np.newaxis] - yy[:, 1][np.newaxis, ...]
    cos_term = np.cos(xx[:, 1])[..., np.newaxis]*np.cos(yy[:, 1])[np.newaxis, ...]
    hav = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2
    # Rounding can push hav slightly above 1 for (near-)antipodal points,
    # which would make arcsin(sqrt(hav)) NaN; keep it in arcsin's domain.
    hav = np.clip(hav, 0.0, 1.0)
    return 2 * radius * np.arcsin(np.sqrt(hav))
def euclidean_distance(xx, yy):
    """Row-wise Euclidean distance between xx and yy.

    yy is broadcast against xx, so it may be a single point of shape (D,)
    or an array with the same (N, D) shape as xx. Returns shape (N,).
    """
    squared_diff = (xx - yy) ** 2
    return np.sqrt(squared_diff.sum(axis=1))
def haversine_distance(xx, yy, radius=6371.4):
    """Great-circle (haversine) distance from each row of xx to the point yy.

    xx has shape (N, 2) and yy is a single point (2,). Index 0 is longitude
    and index 1 is latitude, both in radians. Note that sklearn's haversine
    distance uses [latitude, longitude], the opposite order. The output is in
    the units of ``radius`` (km for the default 6371.4).

    A 1e-16 offset is added under the square root, so identical points give
    a tiny positive distance (about 1.3e-4 km at the default radius) rather
    than exactly 0.
    """
    d_lon = xx[:, 0] - yy[0]  # longitude difference
    d_lat = xx[:, 1] - yy[1]  # latitude difference
    # Product of the cosines of the two latitudes.
    cos_term = np.cos(xx[:, 1])*np.cos(yy[1])
    # sin^2(d_lat/2) + cos(lat1)cos(lat2) * sin^2(d_lon/2)
    hav = np.sin(d_lat/2.0)**2 + cos_term*np.sin(d_lon/2.0)**2
    # Cap at 1: rounding near antipodal points could otherwise exceed
    # arcsin's domain and yield NaN.
    hav = np.minimum(hav + 1e-16, 1.0)
    return 2 * radius * np.arcsin(np.sqrt(hav))
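Haversine 公式为:
$$d = 2R\arcsin\left(\sqrt{\sin^2\left(\frac{\Delta\varphi}{2}\right) + \cos\varphi_1\cos\varphi_2\sin^2\left(\frac{\Delta\lambda}{2}\right)}\right)$$
其中
- R 为地球半径,可取平均值 6371 km;
- φ1, φ2 表示两点的纬度,Δφ = φ2 − φ1 表示两点纬度的差值;
- Δλ 表示两点经度的差值。
根据两个点的经纬度(弧度制),计算这两点之间的球面(大圆)距离。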
bilinear_interpolate
双线性插值
- 输入参数 loc_ip 是一个 N x 2 的矩阵
- 每一行都是 [经度, 纬度] 的格式,范围均在 [-1, 1] 内
- 数据 data 是 H x W x C 维度,即高度 x 宽度 x 通道的数据矩阵
- 输出是 N x C 格式的插值后特征矩阵
- 坐标先被映射到 [0, 1],再缩放到数据的像素尺寸
def bilinear_interpolate(loc_ip, data, remove_nans=False):
    """Bilinearly sample ``data`` at normalised [lon, lat] locations.

    Args:
        loc_ip: N x 2 tensor; each row is a [lon, lat] entry in [-1, 1].
        data: H x W x C (height x width x channel) tensor. Column 0
            corresponds to lon = -1 and row 0 to lat = +1 (top of the map).
        remove_nans: if True, NaN coordinates are replaced by the map centre.

    Returns:
        N x C tensor of interpolated features.

    Locations outside [-1, 1] are clamped to the map border, so indexing
    never goes out of bounds or wraps around through negative indices.
    """
    height, width = data.shape[0], data.shape[1]
    # Map to [0, 1], then scale to the data size.
    loc = (loc_ip.clone() + 1) / 2.0
    # Latitude goes from +90 at the top row down to -90 at the bottom, while
    # longitude goes from -180 to 180 left to right, so only lat is flipped.
    loc[:, 1] = 1 - loc[:, 1]
    if remove_nans:
        loc[torch.isnan(loc)] = 0.5
    loc[:, 0] *= (width - 1)
    loc[:, 1] *= (height - 1)
    # Keep coordinates inside the pixel grid.
    loc[:, 0] = loc[:, 0].clamp(0, width - 1)
    loc[:, 1] = loc[:, 1].clamp(0, height - 1)

    loc_int = torch.floor(loc).long()  # integer pixel coordinates
    xx = loc_int[:, 0]
    yy = loc_int[:, 1]
    # Right / lower neighbours, clamped on the last column / row.
    xx_plus = (xx + 1).clamp(max=width - 1)
    yy_plus = (yy + 1).clamp(max=height - 1)

    loc_delta = loc - torch.floor(loc)  # fractional offsets
    dx = loc_delta[:, 0].unsqueeze(1)
    dy = loc_delta[:, 1].unsqueeze(1)
    interp_val = (data[yy, xx, :]*(1-dx)*(1-dy) + data[yy, xx_plus, :]*dx*(1-dy) +
                  data[yy_plus, xx, :]*(1-dx)*dy + data[yy_plus, xx_plus, :]*dx*dy)
    return interp_val