Contents
- Towards Precise and Efficient Image Guided Depth Completion
- Experiments
- Code walkthrough
- main.py
- kitti_loaders.py
- def load_calib():
- def get_paths_and_transform
- def rgb_read(filename):
- def depth_read(filename):
- def drop_depth_measurements
- def train_transform(rgb, sparse, target, position, args): preprocess the raw images
- def val_transform(rgb, sparse, target, position, args):
- def no_transform(rgb, sparse, target, position, args):
- to_float_tensor
- def handle_gray(rgb, args): convert to grayscale
- def get_rgb_near(path, args):
- class KittiDepth(data.Dataset):
- CoordConv.py
- criteria.py
- helper.py
- metrics.py
- vis_utils.py
- def validcrop(img):
- def depth_colorize(depth):
- def feature_colorize(feature):
- def mask_vis(mask):
- def merge_into_row(ele, pred, predrgb=None, predg=None, extra=None, extra2=None, extrargb=None):
- def add_row(img_merge, row):
- def save_image(img_merge, filename):
- def save_image_torch(rgb, filename):
- def save_depth_as_uint16png(img, filename):
- def save_depth_as_uint16png_upload(img, filename):
- basic.py
Authors' code repository: https://github.com/JUGGHM/PENet_ICRA2021
Towards Precise and Efficient Image Guided Depth Completion
two-branch network
- color-dominant (CD) branch
  color image and a sparse depth map --------- predicted depth map
  Extracts color-dominant information for depth prediction. The predicted depth map is relatively reliable near object boundaries, but it is sensitive to changes in color and texture.
- depth-dominant (DD) branch
  sparse depth map and the previously predicted depth map ----------- dense depth map
  This branch is reliable overall, but the input sparse depth map is severely noisy near object boundaries.
- The strengths and weaknesses of the two branches are largely complementary, so learned confidence weights are used to fuse the two results and make full use of both color and depth information.
- propose a simple geometric convolutional layer to encode 3D geometric cues
  It simply augments a convolutional layer by concatenating a 3D position map to the layer’s input.
- we additionally integrate a module based on CSPN++ to refine the depth map predicted by our backbone
  We design a dilated and accelerated implementation of CSPN++ to make the refinement more effective and efficient.
- SOTA
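The confidence-weighted fusion of the two branches can be sketched as follows. This is a minimal NumPy illustration, assuming the two confidence maps are normalized per pixel with a softmax; the function name `fuse_depths` and the toy values are hypothetical and are not taken from the PENet code.

```python
import numpy as np


def fuse_depths(cd_depth, cd_conf, dd_depth, dd_conf):
    """Fuse two depth maps using per-pixel softmax-normalized confidences.

    Hypothetical helper: the softmax normalization is an assumption.
    """
    conf = np.stack([cd_conf, dd_conf], axis=0)
    conf = conf - conf.max(axis=0, keepdims=True)  # for numerical stability
    weights = np.exp(conf)
    weights /= weights.sum(axis=0, keepdims=True)
    return weights[0] * cd_depth + weights[1] * dd_depth


# Toy 1x2 depth maps: equal confidence at pixel 0, CD branch trusted at pixel 1.
cd_depth = np.array([[10.0, 20.0]])
dd_depth = np.array([[12.0, 30.0]])
cd_conf = np.array([[0.0, 5.0]])
dd_conf = np.array([[0.0, -5.0]])
fused = fuse_depths(cd_depth, cd_conf, dd_depth, dd_conf)
print(fused)
```

With equal confidences the fused value is the plain average of the two predictions; when one confidence dominates, the fused value follows that branch.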
Related Work
- A. Depth Completion
  produce a dense depth map by completing a sparse depth map, without or with the guidance of a reference image
  challenges
  - the input depth map is irregularly sparse and noisy;
  - the color image and the depth map are two different modalities.
- B. Geometric Encoding
  we propose a geometric convolutional layer to encode 3D geometric cues simply (inspired by CoordConv).
- C. Spatial Propagation Networks
  spatial propagation network (SPN)
  convolutional spatial propagation network (CSPN)
  CSPN++ and NLSPN are proposed very recently.
  The former adaptively learns the convolutional kernel size and iteration number for propagation.
  The latter learns deformable kernels.
  We adopt CSPN++ for our depth refinement, but we introduce a dilation scheme to enlarge the neighborhoods and implement the propagation in a much more efficient way.
METHODOLOGY
entire framework---------two-branch backbone and a depth refinement module.
- The Two-branch Backbone
  color-dominant branch
  - an aligned sparse depth map is also input to assist depth prediction
  - The encoder contains one convolution layer and ten basic residual blocks
  - The decoder has five deconvolution layers and one convolution layer
    (each convolution layer is followed by BN and ReLU)
  depth-dominant branch
  - the decoder features of the color-dominant branch are concatenated with the corresponding encoder features in the depth-dominant branch (multi-stage fusion).
  Depth fusion
- The Geometric Convolutional Layer
  - augments a conventional convolutional layer by concatenating a 3D position map to the layer’s input.
  - we replace each convolutional layer within the ResBlocks by the proposed geometric convolutional layer.
- The Dilated and Accelerated CSPN++
  recover the depth values at valid pixels
  introduce a dilation strategy similar to the well-known dilated convolutions to enlarge the propagation neighborhoods.
  our implementation of the translations can be performed in parallel (more efficient).
Experiments
(https://github.com/JUGGHM/PENet_ICRA2021)
- two-branch backbone
- the geometric convolutional layer
- the DA-CSPN++ module.
- we obtain four variants of the backbone, B1 to B4
- Based on the backbone model B4, we further replace each convolutional layer in the ResBlocks by our proposed geometric convolutional layer and get the model B4+GCL.
- Based on the backbone model B4, we integrate variants of CSPN++ to compare their performance. The total number of iterations for propagation is 12.
  - C1 stands for original CSPN++, with a dilation rate (dr) of 1 for all iterations.
  - C2 stands for the model that takes dr = 2 for first six iterations and dr = 1 for the remaining iterations.
  - C4 is the model taking dr = {4, 2, 1} for every four iterations
  - the model B4+C2 slightly outperforms the other two counterparts
- ENet: We also test our geometric encoded backbone without depth refinement (referred to as ENet)
- PENet: we present the quantitative performance of our full method (referred to as PENet)
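The dilation schedules of the C1, C2 and C4 variants can be written out per iteration. The sketch below is only an illustration of the description above, reading "dr = {4, 2, 1} for every four iterations" as four iterations at each rate; `dilation_schedule` is a hypothetical helper, not a function from the repository.

```python
def dilation_schedule(variant, total_iters=12):
    """Return the dilation rate used at each propagation iteration.

    Hypothetical helper: the rates of a variant are spread evenly
    over the total number of iterations.
    """
    rates_by_variant = {
        "C1": [1],        # original CSPN++: dr = 1 throughout
        "C2": [2, 1],     # dr = 2 for the first half, then dr = 1
        "C4": [4, 2, 1],  # dr = 4, 2, 1 for four iterations each
    }
    if variant not in rates_by_variant:
        raise ValueError("unknown variant: {}".format(variant))
    rates = rates_by_variant[variant]
    per_rate = total_iters // len(rates)
    return [rate for rate in rates for _ in range(per_rate)]


for name in ("C1", "C2", "C4"):
    print(name, dilation_schedule(name))
```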
Code walkthrough
PyTorch tensors use the NCHW channel order.
I don't fully understand why the downsampling is done this way:
class SparseDownSampleClose(nn.Module):
    def __init__(self, stride):
        super(SparseDownSampleClose, self).__init__()
        self.pooling = nn.MaxPool2d(stride, stride)
        self.large_number = 600

    def forward(self, d, mask):
        # invalid pixels (mask == 0) are pushed to -large_number before pooling
        encode_d = - (1-mask)*self.large_number - d
        d = - self.pooling(encode_d)
        mask_result = self.pooling(mask)
        d_result = d - (1-mask_result)*self.large_number
        return d_result, mask_result
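One way to read this downsampling: negating the depth and max-pooling is a min-pooling restricted to valid pixels, so each output cell keeps the closest valid depth in its window, and cells with no valid pixel stay at 0. The NumPy sketch below reproduces the same arithmetic on a toy input. It assumes that the height and width are divisible by the stride and that all depths are below `large_number`; `max_pool2d` and `sparse_downsample_close` are hypothetical stand-ins for the PyTorch module.

```python
import numpy as np


def max_pool2d(x, s):
    """Non-overlapping s x s max pooling on a 2D array (H, W divisible by s)."""
    h, w = x.shape
    return x.reshape(h // s, s, w // s, s).max(axis=(1, 3))


def sparse_downsample_close(d, mask, stride=2, large_number=600.0):
    # Valid pixels become -d, invalid pixels become -large_number.
    encode_d = -(1 - mask) * large_number - d
    # Max of the negated values is the minimum valid depth
    # (or large_number if the window has no valid pixel).
    d_min = -max_pool2d(encode_d, stride)
    mask_result = max_pool2d(mask, stride)
    # Windows without any valid pixel are reset to 0.
    d_result = d_min - (1 - mask_result) * large_number
    return d_result, mask_result


d = np.array([[0.0, 8.0, 0.0, 0.0],
              [5.0, 0.0, 0.0, 0.0]])
mask = (d > 0).astype(float)
d_small, mask_small = sparse_downsample_close(d, mask)
print(d_small, mask_small)
```

In the left window the valid depths are 8 and 5, and the closer one (5) is kept; the right window has no measurement, so both its depth and its mask are 0.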
Running load_calib returns the camera intrinsic matrix K:
[[721.5377 0. 596.5593]
[ 0. 721.5377 161.354 ]
[ 0. 0. 1. ]]
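To illustrate how an intrinsic matrix like this relates pixels to 3D positions, the sketch below back-projects a pixel with known depth using the standard pinhole camera model. It is an assumption that the 3D position map is built along these lines; `backproject` is a hypothetical helper and the pixel and depth values are made up.

```python
import numpy as np

# Intrinsic matrix printed above.
K = np.array([[721.5377, 0.0, 596.5593],
              [0.0, 721.5377, 161.354],
              [0.0, 0.0, 1.0]])


def backproject(u, v, z, K):
    """Pinhole back-projection of pixel (u, v) at depth z to camera coordinates."""
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    x = (u - cx) * z / fx
    y = (v - cy) * z / fy
    return np.array([x, y, z])


# A pixel exactly one focal length to the right of the principal point,
# at a depth of 10 m, lies 10 m to the right of the optical axis.
point = backproject(K[0, 2] + K[0, 0], K[1, 2], 10.0, K)
print(point)
```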
BasicBlockGeo is the residual block with its convolutions replaced by the geometric convolutional layer (GCL):
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
#norm_layer = encoding.nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
if stride != 1 or inplanes != planes:
downsample = nn.Sequential(
conv1x1(inplanes, planes, stride),
norm_layer(planes),
)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class BasicBlockGeo(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None, geoplanes=3):
super(BasicBlockGeo, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
#norm_layer = encoding.nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes + geoplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes+geoplanes, planes)
self.bn2 = norm_layer(planes)
if stride != 1 or inplanes != planes:
downsample = nn.Sequential(
conv1x1(inplanes+geoplanes, planes, stride),
norm_layer(planes),
)
self.downsample = downsample
self.stride = stride
def forward(self, x, g1=None, g2=None):
identity = x
if g1 is not None:
x = torch.cat((x, g1), 1)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if g2 is not None:
out = torch.cat((g2,out), 1)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
main.py
parser
- network-model
- workers
- epochs
- start-epoch
- start-epoch-bias
- criterion
- batch-size
- lr
- weight-decay
- print-freq
- resume
- data-folder
- data-folder-rgb
- data-folder-save
- input
- val
- jitter
- rank-metric
- evaluate
- freeze-backbone
- test
- cpu
- not-random-crop
- random-crop-height
- random-crop-width
- convolutional-layer-encoding
- dilation-rate
args = parser.parse_args()
# the derived flags can only be set after the arguments have been parsed
args.use_rgb = ('rgb' in args.input)
args.use_d = 'd' in args.input
args.use_g = 'g' in args.input
args.result = os.path.join('..', 'results')
args.val_h = 352
args.val_w = 1216
print(args)
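The three `use_*` flags are plain substring tests on the `--input` string. A small standalone sketch makes their values explicit for every option. The `--input` definition here (default and `choices`) is an assumption made for illustration; only the option list and the three flag expressions come from the code in this post.

```python
import argparse

input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']


def parse(argv):
    parser = argparse.ArgumentParser()
    # Hypothetical definition of the flag, for illustration only.
    parser.add_argument('--input', type=str, default='rgbd',
                        choices=input_options)
    args = parser.parse_args(argv)
    args.use_rgb = 'rgb' in args.input
    args.use_d = 'd' in args.input
    args.use_g = 'g' in args.input
    return args


flags = {name: parse(['--input', name]) for name in input_options}
for name, a in flags.items():
    print(name, a.use_rgb, a.use_d, a.use_g)
```

Because these are substring tests, `'g' in 'rgbd'` is also true, so `use_g` is set for the `rgb` and `rgbd` inputs as well, and `handle_gray` then returns a grayscale image in addition to the RGB one.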
CUDA or CPU
cuda = torch.cuda.is_available() and not args.cpu
if cuda:
    import torch.backends.cudnn as cudnn
    cudnn.benchmark = True
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("=> using '{}' for computation.".format(device))
loss function
depth_criterion = criteria.MaskedMSELoss() if (
args.criterion == 'l2') else criteria.MaskedL1Loss()
def iterate(mode, args, loader, model, optimizer, logger, epoch):
assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
"unsupported mode: {}".format(mode)
def main():
def main():
global args
checkpoint = None
is_eval = False
if args.evaluate:
args_new = args
if os.path.isfile(args.evaluate):
print("=> loading checkpoint '{}' ... ".format(args.evaluate),
end='')
checkpoint = torch.load(args.evaluate, map_location=device)
#args = checkpoint['args']
args.start_epoch = checkpoint['epoch'] + 1
args.data_folder = args_new.data_folder
args.val = args_new.val
is_eval = True
print("Completed.")
else:
is_eval = True
print("No model found at '{}'".format(args.evaluate))
#return
elif args.resume: # optionally resume from a checkpoint
args_new = args
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}' ... ".format(args.resume),
end='')
checkpoint = torch.load(args.resume, map_location=device)
args.start_epoch = checkpoint['epoch'] + 1
args.data_folder = args_new.data_folder
args.val = args_new.val
print("Completed. Resuming from epoch {}.".format(
checkpoint['epoch']))
else:
print("No checkpoint found at '{}'".format(args.resume))
return
print("=> creating model and optimizer ... ", end='')
model = None
penet_accelerated = False
if (args.network_model == 'e'):
model = ENet(args).to(device)
elif (is_eval == False):
if (args.dilation_rate == 1):
model = PENet_C1_train(args).to(device)
elif (args.dilation_rate == 2):
model = PENet_C2_train(args).to(device)
elif (args.dilation_rate == 4):
model = PENet_C4(args).to(device)
penet_accelerated = True
else:
if (args.dilation_rate == 1):
model = PENet_C1(args).to(device)
penet_accelerated = True
elif (args.dilation_rate == 2):
model = PENet_C2(args).to(device)
penet_accelerated = True
elif (args.dilation_rate == 4):
model = PENet_C4(args).to(device)
penet_accelerated = True
if (penet_accelerated == True):
model.encoder3.requires_grad = False
model.encoder5.requires_grad = False
model.encoder7.requires_grad = False
model_named_params = None
model_bone_params = None
model_new_params = None
optimizer = None
if checkpoint is not None:
#print(checkpoint.keys())
if (args.freeze_backbone == True):
model.backbone.load_state_dict(checkpoint['model'])
else:
model.load_state_dict(checkpoint['model'], strict=False)
#optimizer.load_state_dict(checkpoint['optimizer'])
print("=> checkpoint state loaded.")
logger = helper.logger(args)
if checkpoint is not None:
logger.best_result = checkpoint['best_result']
del checkpoint
print("=> logger created.")
test_dataset = None
test_loader = None
if (args.test):
test_dataset = KittiDepth('test_completion', args)
test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
pin_memory=True)
iterate("test_completion", args, test_loader, model, None, logger, 0)
return
val_dataset = KittiDepth('val', args)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
num_workers=2,
pin_memory=True) # set batch size to be 1 for validation
print("\t==> val_loader size:{}".format(len(val_loader)))
if is_eval == True:
for p in model.parameters():
p.requires_grad = False
result, is_best = iterate("val", args, val_loader, model, None, logger,
args.start_epoch - 1)
return
if (args.freeze_backbone == True):
for p in model.backbone.parameters():
p.requires_grad = False
model_named_params = [
p for _, p in model.named_parameters() if p.requires_grad
]
optimizer = torch.optim.Adam(model_named_params, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.99))
elif (args.network_model == 'pe'):
model_bone_params = [
p for _, p in model.backbone.named_parameters() if p.requires_grad
]
model_new_params = [
p for _, p in model.named_parameters() if p.requires_grad
]
model_new_params = list(set(model_new_params) - set(model_bone_params))
optimizer = torch.optim.Adam([{'params': model_bone_params, 'lr': args.lr / 10}, {'params': model_new_params}],
lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.99))
else:
model_named_params = [
p for _, p in model.named_parameters() if p.requires_grad
]
optimizer = torch.optim.Adam(model_named_params, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.99))
print("completed.")
model = torch.nn.DataParallel(model)
# Data loading code
print("=> creating data loaders ... ")
if not is_eval:
train_dataset = KittiDepth('train', args)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=True,
sampler=None)
print("\t==> train_loader size:{}".format(len(train_loader)))
print("=> starting main loop ...")
for epoch in range(args.start_epoch, args.epochs):
print("=> starting training epoch {} ..".format(epoch))
iterate("train", args, train_loader, model, optimizer, logger, epoch) # train for one epoch
# validation memory reset
for p in model.parameters():
p.requires_grad = False
result, is_best = iterate("val", args, val_loader, model, None, logger, epoch) # evaluate on validation set
for p in model.parameters():
p.requires_grad = True
if (args.freeze_backbone == True):
for p in model.module.backbone.parameters():
p.requires_grad = False
if (penet_accelerated == True):
model.module.encoder3.requires_grad = False
model.module.encoder5.requires_grad = False
model.module.encoder7.requires_grad = False
helper.save_checkpoint({ # save checkpoint
'epoch': epoch,
'model': model.module.state_dict(),
'best_result': logger.best_result,
'optimizer' : optimizer.state_dict(),
'args' : args,
}, is_best, epoch, logger.output_directory)
kitti_loaders.py
input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']
def load_calib():
def load_calib():
Temporarily hard-codes the calibration matrix using the calib file from 2011_09_26.
def get_paths_and_transform
transform = no_transform
glob_d = os.path.join( args.data_folder,"data_depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png")
glob_gt = None # "test_depth_completion_anonymous/"
glob_rgb = os.path.join( args.data_folder,"data_depth_selection/test_depth_completion_anonymous/image/*.png")
paths_rgb = sorted(glob.glob(glob_rgb))
paths_gt = [None] * len(paths_rgb)
paths_d = sorted(glob.glob(glob_d))
paths = {"rgb": paths_rgb, "d": paths_d, "gt": paths_gt}
return paths, transform
def rgb_read(filename):
def rgb_read(filename):
    assert os.path.exists(filename), "file not found: {}".format(filename)
    img_file = Image.open(filename)
    # rgb_png = np.array(img_file, dtype=float) / 255.0 # scale pixels to the range [0,1]
    rgb_png = np.array(img_file, dtype='uint8')  # in the range [0,255]
    img_file.close()
    return rgb_png
def depth_read(filename):
def depth_read(filename):
    # loads depth map D from png file
    # and returns it as a numpy array,
    # for details see readme.txt
    assert os.path.exists(filename), "file not found: {}".format(filename)
    img_file = Image.open(filename)
    depth_png = np.array(img_file, dtype=int)
    img_file.close()
    # make sure we have a proper 16bit depth map here.. not 8bit!
    assert np.max(depth_png) > 255, \
        "np.max(depth_png)={}, path={}".format(np.max(depth_png), filename)
    # np.float was removed from recent NumPy releases; the builtin float is equivalent
    depth = depth_png.astype(float) / 256.
    # depth[depth_png == 0] = -1.
    depth = np.expand_dims(depth, -1)
    return depth
def drop_depth_measurements
def drop_depth_measurements(depth, prob_keep):
    # keeps each measurement with probability prob_keep (modifies depth in place)
    mask = np.random.binomial(1, prob_keep, depth.shape)
    depth *= mask
    return depth
def train_transform(rgb, sparse, target, position, args): preprocess the raw images
def train_transform(rgb, sparse, target, position, args):
    oheight = args.val_h
    owidth = args.val_w
    do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip
    transforms_list = [
        # transforms.Rotate(angle),
        # transforms.Resize(s),
        transforms.BottomCrop((oheight, owidth)),
        transforms.HorizontalFlip(do_flip)
    ]
    # transforms.Compose chains several image transforms together
    transform_geometric = transforms.Compose(transforms_list)
    if sparse is not None:
        sparse = transform_geometric(sparse)
    target = transform_geometric(target)
    if rgb is not None:
        brightness = np.random.uniform(max(0, 1 - args.jitter),
                                       1 + args.jitter)
        contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter)
        saturation = np.random.uniform(max(0, 1 - args.jitter),
                                       1 + args.jitter)
        transform_rgb = transforms.Compose([
            transforms.ColorJitter(brightness, contrast, saturation, 0),
            transform_geometric
        ])
        rgb = transform_rgb(rgb)
    if position is not None:
        bottom_crop_only = transforms.Compose([transforms.BottomCrop((oheight, owidth))])
        position = bottom_crop_only(position)
    # random crop: the same window (i, j) is applied to every input
    if args.not_random_crop == False:
        h = oheight
        w = owidth
        rheight = args.random_crop_height
        rwidth = args.random_crop_width
        # randomize
        i = np.random.randint(0, h - rheight + 1)
        j = np.random.randint(0, w - rwidth + 1)
        if rgb is not None:
            if rgb.ndim == 3:
                rgb = rgb[i:i + rheight, j:j + rwidth, :]
            elif rgb.ndim == 2:
                rgb = rgb[i:i + rheight, j:j + rwidth]
        if sparse is not None:
            if sparse.ndim == 3:
                sparse = sparse[i:i + rheight, j:j + rwidth, :]
            elif sparse.ndim == 2:
                sparse = sparse[i:i + rheight, j:j + rwidth]
        if target is not None:
            if target.ndim == 3:
                target = target[i:i + rheight, j:j + rwidth, :]
            elif target.ndim == 2:
                target = target[i:i + rheight, j:j + rwidth]
        if position is not None:
            if position.ndim == 3:
                position = position[i:i + rheight, j:j + rwidth, :]
            elif position.ndim == 2:
                position = position[i:i + rheight, j:j + rwidth]
    return rgb, sparse, target, position
def val_transform(rgb, sparse, target, position, args):
def val_transform(rgb, sparse, target, position, args):
    oheight = args.val_h
    owidth = args.val_w
    transform = transforms.Compose([
        transforms.BottomCrop((oheight, owidth)),
    ])
    if rgb is not None:
        rgb = transform(rgb)
    if sparse is not None:
        sparse = transform(sparse)
    if target is not None:
        target = transform(target)
    if position is not None:
        position = transform(position)
    return rgb, sparse, target, position
def no_transform(rgb, sparse, target, position, args):
def no_transform(rgb, sparse, target, position, args):
    return rgb, sparse, target, position
to_float_tensor
to_tensor = transforms.ToTensor()
to_float_tensor = lambda x: to_tensor(x).float()
def handle_gray(rgb, args): convert to grayscale
def handle_gray(rgb, args):
    if rgb is None:
        return None, None
    if not args.use_g:
        return rgb, None
    else:
        img = np.array(Image.fromarray(rgb).convert('L'))
        img = np.expand_dims(img, -1)
        if not args.use_rgb:
            rgb_ret = None
        else:
            rgb_ret = rgb
        return rgb_ret, img
img.convert('L')
Converts the image to 8-bit grayscale: each pixel is stored in 8 bits, where 0 is black, 255 is white, and the values in between are shades of gray.
Conversion formula: L = R * 299/1000 + G * 587/1000 + B * 114/1000.
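As a quick worked example of the conversion formula, the sketch below evaluates it for a few pixels in plain Python. The helper name `luma` is hypothetical, and the result is left as a float; the rounding that an image library applies when storing 8-bit values is not modeled here.

```python
def luma(r, g, b):
    """Grayscale value from the weighted sum L = 0.299 R + 0.587 G + 0.114 B."""
    return r * 299 / 1000 + g * 587 / 1000 + b * 114 / 1000


for name, pixel in [("black", (0, 0, 0)),
                    ("white", (255, 255, 255)),
                    ("pure red", (255, 0, 0)),
                    ("pure green", (0, 255, 0))]:
    print(name, round(luma(*pixel), 3))
```

The three weights sum to 1, so white maps to 255, and green contributes the most to perceived brightness.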
def get_rgb_near(path, args):
def get_rgb_near(path, args):
    assert path is not None, "path is None"

    def extract_frame_id(filename):
        head, tail = os.path.split(filename)
        number_string = tail[0:tail.find('.')]
        number = int(number_string)
        return head, number

    def get_nearby_filename(filename, new_id):
        head, _ = os.path.split(filename)
        new_filename = os.path.join(head, '%010d.png' % new_id)
        return new_filename

    head, number = extract_frame_id(path)
    count = 0
    max_frame_diff = 3
    # frame offsets -3..3, excluding 0 (the current frame)
    candidates = [
        i - max_frame_diff for i in range(max_frame_diff * 2 + 1)
        if i - max_frame_diff != 0
    ]
    while True:
        random_offset = choice(candidates)
        path_near = get_nearby_filename(path, number + random_offset)
        if os.path.exists(path_near):
            break
        # count is never incremented in the listing as posted, so the limit
        # below could not trigger; the increment is added here
        count += 1
        assert count < 20, "cannot find a nearby frame in 20 trials for {}".format(path_near)
    return rgb_read(path_near)
class KittiDepth(data.Dataset):
Every non-None entry is converted to a float tensor before it is returned:
candidates = {"rgb": rgb, "d": sparse, "gt": target, \
    "g": gray, 'position': position, 'K': self.K}
items = {
    key: to_float_tensor(val)
    for key, val in candidates.items() if val is not None
}
return items
class KittiDepth(data.Dataset):
"""A data loader for the Kitti dataset
"""
def __init__(self, split, args):
self.args = args
self.split = split
paths, transform = get_paths_and_transform(split, args)
self.paths = paths
self.transform = transform
self.K = load_calib()
self.threshold_translation = 0.1
def __getraw__(self, index):
rgb = rgb_read(self.paths['rgb'][index]) if \
(self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None
sparse = depth_read(self.paths['d'][index]) if \
(self.paths['d'][index] is not None and self.args.use_d) else None
target = depth_read(self.paths['gt'][index]) if \
self.paths['gt'][index] is not None else None
return rgb, sparse, target
def __getitem__(self, index):
rgb, sparse, target = self.__getraw__(index)
position = CoordConv.AddCoordsNp(self.args.val_h, self.args.val_w)
position = position.call()
rgb, sparse, target, position = self.transform(rgb, sparse, target, position, self.args)
rgb, gray = handle_gray(rgb, self.args)
# candidates = {"rgb": rgb, "d": sparse, "gt": target, \
# "g": gray, "r_mat": r_mat, "t_vec": t_vec, "rgb_near": rgb_near}
candidates = {"rgb": rgb, "d": sparse, "gt": target, \
"g": gray, 'position': position, 'K': self.K}
items = {
key: to_float_tensor(val)
for key, val in candidates.items() if val is not None
}
return items
def __len__(self):
return len(self.paths['gt'])
CoordConv.py
- adds the x and y coordinate channels
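The idea of adding x and y coordinate channels can be sketched in a few lines of NumPy. This is an illustration only: `coord_channels` is a hypothetical helper, and normalizing the coordinates to [-1, 1] is an assumption rather than something stated in this post.

```python
import numpy as np


def coord_channels(height, width):
    """Return an H x W x 2 array holding the x and y coordinate of each pixel.

    Assumption: coordinates are normalized to the range [-1, 1].
    """
    xs = np.linspace(-1.0, 1.0, width)
    ys = np.linspace(-1.0, 1.0, height)
    xx, yy = np.meshgrid(xs, ys)  # both have shape (height, width)
    return np.stack([xx, yy], axis=-1)


position = coord_channels(4, 6)
print(position.shape)
print(position[0, 0], position[-1, -1])
```

Concatenating such a map to a feature map along the channel axis gives the following convolution access to where each pixel is located in the image.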
criteria.py
- selects the loss function: L1 or L2
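A masked loss only averages over pixels that have ground truth. The NumPy sketch below shows the idea behind the L2 and L1 variants; it assumes that valid pixels are those with `target > 0`, and the function names `masked_mse` and `masked_l1` are hypothetical rather than the classes defined in criteria.py.

```python
import numpy as np


def masked_mse(pred, target):
    """Mean squared error over pixels with ground truth (assumed: target > 0)."""
    valid = target > 0
    diff = target[valid] - pred[valid]
    return float((diff ** 2).mean())


def masked_l1(pred, target):
    """Mean absolute error over pixels with ground truth (assumed: target > 0)."""
    valid = target > 0
    diff = target[valid] - pred[valid]
    return float(np.abs(diff).mean())


pred = np.array([[1.0, 2.0],
                 [3.0, 4.0]])
target = np.array([[0.0, 2.5],
                   [0.0, 3.0]])  # zeros mark pixels without ground truth
print(masked_mse(pred, target), masked_l1(pred, target))
```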
helper.py
fieldnames = [
'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10',
'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time',
'gpu_time'
]
class logger:
- output_directory
- best_result
- output files:
self.train_csv = os.path.join(output_directory, 'train.csv')
self.val_csv = os.path.join(output_directory, 'val.csv')
self.best_txt = os.path.join(output_directory, 'best.txt')
# backup the source code
print("=> creating source code backup ...")
backup_directory = os.path.join(output_directory, "code_backup")
self.backup_directory = backup_directory
backup_source_code(backup_directory)
with open(self.train_csv, 'w') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
with open(self.val_csv, 'w') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
print("=> finished creating source code backup.")
def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter, avg_meter):
if (i + 1) % self.args.print_freq == 0:
avg = avg_meter.average()
blk_avg = blk_avg_meter.average()
print('=> output: {}'.format(self.output_directory))
print(
'{split} Epoch: {0} [{1}/{2}]\tlr={lr} '
't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
.format(epoch,
i + 1,
n_set,
lr=lr,
blk_avg=blk_avg,
average=avg,
split=split.capitalize()))
blk_avg_meter.reset(False)
### write the CSV file
def conditional_save_info(self, split, average_meter, epoch):
avg = average_meter.average()
if split == "train":
csvfile_name = self.train_csv
elif split == "val":
csvfile_name = self.val_csv
elif split == "eval":
eval_filename = os.path.join(self.output_directory, 'eval.txt')
self.save_single_txt(eval_filename, avg, epoch)
return avg
elif "test" in split:
return avg
else:
raise ValueError("wrong split provided to logger")
with open(csvfile_name, 'a') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writerow({
'epoch': epoch,
'rmse': avg.rmse,
'photo': avg.photometric,
'mae': avg.mae,
'irmse': avg.irmse,
'imae': avg.imae,
'mse': avg.mse,
'silog': avg.silog,
'squared_rel': avg.squared_rel,
'absrel': avg.absrel,
'lg10': avg.lg10,
'delta1': avg.delta1,
'delta2': avg.delta2,
'delta3': avg.delta3,
'gpu_time': avg.gpu_time,
'data_time': avg.data_time
})
return avg
### write the txt file
def save_single_txt(self, filename, result, epoch):
with open(filename, 'w') as txtfile:
txtfile.write(
("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" +
"mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" +
"irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" +
"absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" +
"t_gpu={:.4f}").format(self.args.rank_metric, epoch,
result.rmse, result.mae, result.silog,
result.squared_rel, result.irmse,
result.imae, result.mse, result.absrel,
result.lg10, result.delta1,
result.gpu_time))
def save_best_txt(self, result, epoch):
self.save_single_txt(self.best_txt, result, epoch)
def adjust_learning_rate(lr_init, optimizer, epoch, args):
def adjust_learning_rate(lr_init, optimizer, epoch, args):
    """Sets the learning rate with a step schedule that depends on the model and the epoch"""
    #lr = lr_init * (0.5**(epoch // 5))
    lr = lr_init
    if (args.network_model == 'pe' and args.freeze_backbone == False):
        if (epoch >= 10):
            lr = lr_init * 0.5
        if (epoch >= 20):
            lr = lr_init * 0.1
        if (epoch >= 30):
            lr = lr_init * 0.01
        if (epoch >= 40):
            lr = lr_init * 0.0005
        if (epoch >= 50):
            lr = lr_init * 0.00001
    else:
        if (epoch >= 10):
            lr = lr_init * 0.5
        if (epoch >= 15):
            lr = lr_init * 0.1
        if (epoch >= 25):
            lr = lr_init * 0.01

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
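The chain of `if` statements above amounts to a step schedule: the multiplier of the last threshold that the epoch has reached wins. A table-driven sketch makes this easier to read. The names `PE_SCHEDULE`, `DEFAULT_SCHEDULE` and `lr_multiplier` are hypothetical; the thresholds and factors are copied from the function above.

```python
# (first epoch, multiplier) pairs, copied from adjust_learning_rate.
PE_SCHEDULE = [(10, 0.5), (20, 0.1), (30, 0.01), (40, 0.0005), (50, 0.00001)]
DEFAULT_SCHEDULE = [(10, 0.5), (15, 0.1), (25, 0.01)]


def lr_multiplier(epoch, schedule):
    """Return the factor applied to the initial learning rate at this epoch."""
    multiplier = 1.0
    for first_epoch, factor in schedule:
        if epoch >= first_epoch:
            multiplier = factor
    return multiplier


lr_init = 1e-3  # made-up initial learning rate for the demonstration
for epoch in (0, 12, 17, 26, 60):
    print(epoch,
          lr_init * lr_multiplier(epoch, PE_SCHEDULE),
          lr_init * lr_multiplier(epoch, DEFAULT_SCHEDULE))
```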
def save_checkpoint(state, is_best, epoch, output_directory):
def save_checkpoint(state, is_best, epoch, output_directory):
    checkpoint_filename = os.path.join(output_directory,
                                       'checkpoint-' + str(epoch) + '.pth.tar')
    torch.save(state, checkpoint_filename)
    if is_best:
        best_filename = os.path.join(output_directory, 'model_best.pth.tar')
        shutil.copyfile(checkpoint_filename, best_filename)
    if epoch > 0:
        # keep only the latest checkpoint (plus the best one)
        prev_checkpoint_filename = os.path.join(
            output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
        if os.path.exists(prev_checkpoint_filename):
            os.remove(prev_checkpoint_filename)
def backup_source_code(backup_directory):
- Definition of the source-code backup function
ignore_hidden = shutil.ignore_patterns(".", "..", ".git*", "*pycache*",
                                       "*build", "*.fuse*", "*_drive_*")
def backup_source_code(backup_directory):
    if os.path.exists(backup_directory):
        shutil.rmtree(backup_directory)
    shutil.copytree('.', backup_directory, ignore=ignore_hidden)
def get_folder_name(args):
def get_folder_name(args):
    current_time = time.strftime('%Y-%m-%d@%H-%M')
    return os.path.join(
        args.result,
        'input={}.criterion={}.lr={}.bs={}.wd={}.jitter={}.time={}'.format(
            args.input, args.criterion, args.lr, args.batch_size,
            args.weight_decay, args.jitter, current_time))
metrics.py
class Result
class Result(object):
def __init__(self):
self.irmse = 0
self.imae = 0
self.mse = 0
self.rmse = 0
self.mae = 0
self.absrel = 0
self.squared_rel = 0
self.lg10 = 0
self.delta1 = 0
self.delta2 = 0
self.delta3 = 0
self.data_time = 0
self.gpu_time = 0
self.silog = 0 # Scale invariant logarithmic error [log(m)*100]
self.photometric = 0
def set_to_worst(self):
self.irmse = np.inf
self.imae = np.inf
self.mse = np.inf
self.rmse = np.inf
self.mae = np.inf
self.absrel = np.inf
self.squared_rel = np.inf
self.lg10 = np.inf
self.silog = np.inf
self.delta1 = 0
self.delta2 = 0
self.delta3 = 0
self.data_time = 0
self.gpu_time = 0
def update(self, irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, \
delta1, delta2, delta3, gpu_time, data_time, silog, photometric=0):
self.irmse = irmse
self.imae = imae
self.mse = mse
self.rmse = rmse
self.mae = mae
self.absrel = absrel
self.squared_rel = squared_rel
self.lg10 = lg10
self.delta1 = delta1
self.delta2 = delta2
self.delta3 = delta3
self.data_time = data_time
self.gpu_time = gpu_time
self.silog = silog
self.photometric = photometric
def evaluate(self, output, target, photometric=0):
valid_mask = target > 0.1
# convert from meters to mm
output_mm = 1e3 * output[valid_mask]
target_mm = 1e3 * target[valid_mask]
abs_diff = (output_mm - target_mm).abs()
self.mse = float((torch.pow(abs_diff, 2)).mean())
self.rmse = math.sqrt(self.mse)
self.mae = float(abs_diff.mean())
self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean())
self.absrel = float((abs_diff / target_mm).mean())
self.squared_rel = float(((abs_diff / target_mm)**2).mean())
maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm)
self.delta1 = float((maxRatio < 1.25).float().mean())
self.delta2 = float((maxRatio < 1.25**2).float().mean())
self.delta3 = float((maxRatio < 1.25**3).float().mean())
self.data_time = 0
self.gpu_time = 0
# silog uses meters
err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask])
normalized_squared_log = (err_log**2).mean()
log_mean = err_log.mean()
self.silog = math.sqrt(normalized_squared_log -
log_mean * log_mean) * 100
# convert from meters to km
inv_output_km = (1e-3 * output[valid_mask])**(-1)
inv_target_km = (1e-3 * target[valid_mask])**(-1)
abs_inv_diff = (inv_output_km - inv_target_km).abs()
self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
self.imae = float(abs_inv_diff.mean())
self.photometric = float(photometric)
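The unit conversions in `evaluate` are easy to check on a toy example: RMSE and MAE are reported in millimeters, iRMSE and iMAE in 1/km, and only pixels with `target > 0.1` count. The NumPy sketch below recomputes these four metrics; `depth_metrics` is a hypothetical helper that mirrors the arithmetic above, and the input values are made up.

```python
import numpy as np


def depth_metrics(output_m, target_m):
    """RMSE/MAE in mm and iRMSE/iMAE in 1/km, on pixels with target > 0.1 m."""
    valid = target_m > 0.1
    abs_diff_mm = np.abs(1e3 * output_m[valid] - 1e3 * target_m[valid])
    inv_output_km = 1.0 / (1e-3 * output_m[valid])
    inv_target_km = 1.0 / (1e-3 * target_m[valid])
    abs_inv_diff = np.abs(inv_output_km - inv_target_km)
    return {
        "rmse": float(np.sqrt((abs_diff_mm ** 2).mean())),
        "mae": float(abs_diff_mm.mean()),
        "irmse": float(np.sqrt((abs_inv_diff ** 2).mean())),
        "imae": float(abs_inv_diff.mean()),
    }


output = np.array([10.0, 20.0, 5.0])  # predicted depth in meters
target = np.array([10.0, 25.0, 0.0])  # ground truth; 0 means no measurement
metrics = depth_metrics(output, target)
print(metrics)
```

Here the second pixel is off by 5 m (5000 mm), which gives an MAE of 2500 mm over the two valid pixels; in inverse depth the same pixel differs by 50 - 40 = 10 per km, so iMAE is 5.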
class AverageMeter(object):
class AverageMeter(object):
def __init__(self):
self.reset(time_stable=True)
def reset(self, time_stable):
self.count = 0.0
self.sum_irmse = 0
self.sum_imae = 0
self.sum_mse = 0
self.sum_rmse = 0
self.sum_mae = 0
self.sum_absrel = 0
self.sum_squared_rel = 0
self.sum_lg10 = 0
self.sum_delta1 = 0
self.sum_delta2 = 0
self.sum_delta3 = 0
self.sum_data_time = 0
self.sum_gpu_time = 0
self.sum_photometric = 0
self.sum_silog = 0
self.time_stable = time_stable
self.time_stable_counter_init = 10
self.time_stable_counter = self.time_stable_counter_init
def update(self, result, gpu_time, data_time, n=1):
self.count += n
self.sum_irmse += n * result.irmse
self.sum_imae += n * result.imae
self.sum_mse += n * result.mse
self.sum_rmse += n * result.rmse
self.sum_mae += n * result.mae
self.sum_absrel += n * result.absrel
self.sum_squared_rel += n * result.squared_rel
self.sum_lg10 += n * result.lg10
self.sum_delta1 += n * result.delta1
self.sum_delta2 += n * result.delta2
self.sum_delta3 += n * result.delta3
self.sum_data_time += n * data_time
if self.time_stable == True and self.time_stable_counter > 0:
self.time_stable_counter = self.time_stable_counter - 1
else:
self.sum_gpu_time += n * gpu_time
self.sum_silog += n * result.silog
self.sum_photometric += n * result.photometric
def average(self):
avg = Result()
if self.time_stable == True:
if self.count > 0 and self.count - self.time_stable_counter_init > 0:
avg.update(
self.sum_irmse / self.count, self.sum_imae / self.count,
self.sum_mse / self.count, self.sum_rmse / self.count,
self.sum_mae / self.count, self.sum_absrel / self.count,
self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
self.sum_delta1 / self.count, self.sum_delta2 / self.count,
self.sum_delta3 / self.count, self.sum_gpu_time / (self.count - self.time_stable_counter_init),
self.sum_data_time / self.count, self.sum_silog / self.count,
self.sum_photometric / self.count)
elif self.count > 0:
avg.update(
self.sum_irmse / self.count, self.sum_imae / self.count,
self.sum_mse / self.count, self.sum_rmse / self.count,
self.sum_mae / self.count, self.sum_absrel / self.count,
self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
self.sum_delta1 / self.count, self.sum_delta2 / self.count,
self.sum_delta3 / self.count, 0,
self.sum_data_time / self.count, self.sum_silog / self.count,
self.sum_photometric / self.count)
elif self.count > 0:
avg.update(
self.sum_irmse / self.count, self.sum_imae / self.count,
self.sum_mse / self.count, self.sum_rmse / self.count,
self.sum_mae / self.count, self.sum_absrel / self.count,
self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
self.sum_delta1 / self.count, self.sum_delta2 / self.count,
self.sum_delta3 / self.count, self.sum_gpu_time / self.count,
self.sum_data_time / self.count, self.sum_silog / self.count,
self.sum_photometric / self.count)
return avg
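The `time_stable` logic exists so that the first few iterations, which are typically slower while the GPU warms up, do not distort the average GPU time. A stripped-down sketch of just that part, for updates with n = 1, could look like this; the class name `TimeStableMean` is hypothetical and only the GPU time is tracked.

```python
class TimeStableMean:
    """Average GPU time that ignores the first `warmup` measurements."""

    def __init__(self, warmup=10):
        self.warmup = warmup
        self.count = 0
        self.total = 0.0

    def update(self, gpu_time):
        self.count += 1
        # The first `warmup` measurements are counted but not accumulated.
        if self.count > self.warmup:
            self.total += gpu_time

    def average(self):
        timed = self.count - self.warmup
        return self.total / timed if timed > 0 else 0.0


meter = TimeStableMean(warmup=10)
for _ in range(10):
    meter.update(5.0)  # slow warm-up iterations
for _ in range(5):
    meter.update(1.0)  # steady-state iterations
print(meter.average())

short_meter = TimeStableMean(warmup=10)
for _ in range(3):
    short_meter.update(5.0)
print(short_meter.average())
```

As in `AverageMeter.average`, the sum is divided by the number of iterations after the warm-up, and the reported time is 0 until more than ten iterations have been seen.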