引言(Introduction)
对于第一次接触神经网络框架的同学来说,可能会碰到看不懂源码,或者碰到大量代码不知如何下手的情况。因此本文章的侧重点是解释源码的逻辑和含义,同时简单扩展相关知识点。为了保证行文的总体逻辑,部分较为深入的知识会放在附录。若有不当,欢迎大家批评指正。
一、以PyTorch-SuperPoint为例
关键点检测和描述符生成是基础但至关重要的任务,它们被广泛应用于图像匹配、3D重建、物体识别等多个场景。本文以PyTorch-SuperPoint项目为例,介绍图像领域的源码结构。PyTorch-SuperPointhttps://github.com/eric-yyjau/pytorch-superpoint
1.1 了解项目框架以及使用--README.md
当我们接触到一个新项目时,最先阅读的就是README.md,该文件中会有作者给出的关于该项目的介绍,其一般会包含以下内容,部分作者还会给出代码的实现效果以及各模块的实现逻辑以及功能。
- 代码框架,每个文件功能或者文件夹存储着什么文件.
- 所需的环境配置,需要下载安装哪些包.
- 运行代码的方法

1.2 安装配置环境 -- requirements.txt
要跑通项目,首先得先安装配置环境,我们先看到README.md中的介绍,该项目的用到的时python 3.6,pytorch 1.3.1,torchvision 0.4.2,cuda 10,在安装配置环境的时候,尽量配置与项目给出的版本相同,不然很可能出现版本不兼容报错的情况。
- python 编程语言,其版本主要分为2.x和3.x两种,这两种差别比较大,
- pytorch,torchvision python库,你可以把它理解为工具箱,一些底层的实现都被封装好了,你只需要调用库里的类或者函数就行
- cuda 并行计算平台,我们可以通过cuda,将计算任务分配给GPU进行并行处理,从而加速各种计算密集型应用程序。
1.2.1 安装cuda
首先,我们来安装cuda。这块可能会碰到很多名词:GPU(显卡),显卡驱动,CUDA,CUDA Toolkit,cudnn,它们之间关系见附录1,这里我们只讲如何进行安装。
(1)检查电脑是否支持安装cuda
由于CUDA是利用GPU的并行计算能力来进行加速,所以要求你的电脑上必须有GPU,也就是我们常说的”显卡“,且该GPU支持CUDA,其检查步骤如下:
右键'此电脑'->点击属性->选择设备管理器->选择显示设备器,即可获得你的GPU,然后在官网上看你的GPU是否支持CUDA
CUDA GPUs - Compute Capability | NVIDIA Developerhttps://developer.nvidia.com/cuda-gpus(2)检查是否有显卡驱动
对于GPU而言,需要显卡驱动来调用它,不然它就是一张卡,打开终端(linux下打开终端,windows下打开cmd),输入指令
nvidia-smi
在windows下,如果提示“nvidia -smi显示不是内部或外部指令”,可以先输入【cd C:\Program Files\NVIDIA Corporation\NVSMI】再使用【nvidia-smi】命令。如果查询到Driver Version,说明已经安装了显卡驱动。
如果没有安装的话,就根据显卡型号在官网选择合适的显卡驱动,下载即可
(3)检查是否安装cuda
nvcc -V
如果出现找不到命令,则证明未安装cuda
进入官网,安装cuda,其中版本的选择要考虑两个因素:
- cuda有对应的最低显卡驱动版本
-
cuda与pytorch版本有对应关系
CUDA Toolkit Archive | NVIDIA Developerhttps://developer.nvidia.com/cuda-toolkit-archive
在安装时,请记住一下安装位置,这在后面配置环境变量会用到。
安装完成之后需要配置环境变量。
在windows下载时,右键此电脑->打开属性->高级系统设置,然后在环境变量里添加路径。
在linux下,通过vim ~/.bashrc修改.bashrc 文件,然后通过 source ~/.bashrc 更新变量。
然后命令行中键入 nvcc -V 看一下是否安装成功
(4)安装cudnn*(可选)
cuDNN是NVIDIA的深度神经网络库,可以加速深度学习任务。在官网下载适用于你的cuda版本的cuDNN,解压之后,它是三个文件夹和一个.txt,将三个文件夹拷贝到cuda的安装目录
CUDA Deep Neural Network (cuDNN) | NVIDIA Developerhttps://developer.nvidia.com/cudnn
1.2.2 安装Anaconda
Anaconda 是用于环境管理的,我们可以通过它安装python以及各种库,由于不同项目需要的python版本和库版本都不同,所以可能会发生版本冲突,因此我们可以通过Anaconda为项目创建一个虚拟环境,将不同项目隔离开来。
其在官网下载即可。
Free Download | Anacondahttps://www.anaconda.com/download/下载完成之后,我们打开终端,通过cd指令来到项目路径下,通过conda指令创建虚拟环境,然后通过conda activate激活该虚拟环境,例如在README.md中,我们可以看到官方给出的指令,就是创建了一个名为py36-sp的虚拟环境,并激活该虚拟环境。
conda create --name py36-sp python=3.6
conda activate py36-sp
创建虚拟环境之后,我们可以在该虚拟环境下安装我们需要的库,对于项目而言,所需要的库和相应版本都会写在requirements.txt中,所以只需要通过在终端输入pip install -r requirements.txt即可下载所有需要的库,这里给出了两个,因为第一个文件里给出的框架是tensorflow,作者给出了tensorflow和pytorch两种框架的实现,所以有两个文件。
pip install -r requirements.txt
pip install -r requirements_torch.txt # install pytorch
1.3 代码的运行
在README.md中,基本都会给出项目是怎么运行的,我们以其给出的第一个运行为例,来讲讲如何去快速看算法的实现逻辑。首先根据README.md,我们知道这个运行是在训练模型,由于这个项目里有很多数据集和网络模型,所以我们得先确定它用的是哪一个数据集和模型,以及它的训练过程是怎么样的。
1.3.1 从命令行找到执行文件,定位执行文件主函数
首先是运行指令,python train4.py表示执行该py文件,所以我们可以直接找到该.py文件中的主函数部分
if __name__ == '__main__':
# global var
torch.set_default_tensor_type(torch.FloatTensor)
logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
# add parser
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='command')
# Training command
p_train = subparsers.add_parser('train_base')
p_train.add_argument('config', type=str)
p_train.add_argument('exper_name', type=str)
p_train.add_argument('--eval', action='store_true')
p_train.add_argument('--debug', action='store_true', default=False,
help='turn on debuging mode')
p_train.set_defaults(func=train_base)
# Training command
p_train = subparsers.add_parser('train_joint')
p_train.add_argument('config', type=str)
p_train.add_argument('exper_name', type=str)
p_train.add_argument('--eval', action='store_true')
p_train.add_argument('--debug', action='store_true', default=False,
help='turn on debuging mode')
p_train.set_defaults(func=train_joint)
args = parser.parse_args()
if args.debug:
logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
datefmt='%m/%d/%Y %H:%M:%S', level=logging.DEBUG)
with open(args.config, 'r') as f:
config = yaml.safe_load(f)
# EXPER_PATH from settings.py
output_dir = os.path.join(EXPER_PATH, args.exper_name)
os.makedirs(output_dir, exist_ok=True)
# with capture_outputs(os.path.join(output_dir, 'log')):
logging.info('Running command {}'.format(args.command.upper()))
args.func(config, output_dir, args)
对于第一次接触pytorch或者不熟悉python的同学来说,这样的写法可能比较陌生,我们逐部分来进行讲解
(1)设置数据类型
这句代码的含义就是设置默认数据类型为torch.FloatTensor
torch.set_default_tensor_type(torch.FloatTensor)
这里简单介绍一下Tensor
Tensor,张量,是包含单一数据类型元素的多维数组,其出现的目的就是为了描述和处理高维数据
(2)设置日志
在项目中,通过日志来记录运行情况,对于查错和运维都有很大的帮助,通过日志,我们可以清楚在运行过程中发生了哪些事件。
我们会首先创建一个日志,然后在日志中添加日志信息来记录程序运行情况,对于日志信息,我们根据其重要程度对其进行分类,其数值越高,表示其级别越高
NOTSET | 0 | 其它 |
DEBUG | 10 | 调试信息,一般给开发人员看的,用于判断程序中间结果是否正确 |
INFO | 20 | 一般信息,程序正常运行时输出的日志信息,确认程序按照预期运行 |
WARNING | 30 | 警告信息,表示虽然程序继续运行,但是接下来可能会出现问题 |
ERROR | 40 | 错误信息,程序发生了错误 |
CRITICAL | 50 | 严重错误,发生了严重错误,程序无法运行 |
logging就是是Python内置的日志记录模块,通过这个模块,我们可以创建日志,添加日志信息。比如logging.error("ERROR")就表示在日志中添加一条错误日志,其信息为ERROR,这里可以看到,日志的作用是用于记录,而为什么使用日志而不用printf来打印呢,这是因为日志是可以设置等级的,其等级有DEBUG,INFO,WARNING,ERROR,CRITICAL,只有日志信息大于等于日志等级,该信息才会被显示。
举个例子,在代码调试中,我们可以通过在代码里添加printf的方式打印中间结果来看程序运行的如何,但是当我们调试完成之后,这些调试信息对我们而言就失去用处了,如果一个个删除,对于工程量很大的代码来说是很麻烦的,而且如果以后需要再次对代码进行调试,又要将其添加回来,而如果使用日志的形式,只需要将日志的等级设置的大于DEBUG,这些调试信息就不会打印出来。当然,日志还有其它优势,这里就不继续扩展。
这句代码的含义就是创建了一个日志,并规定了其日志格式format和日志等级level
logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
(3)设置解释器
argparse是用来解析命令行参数的库,对于该网络,如果我想更换一个数据集,我希望我可以直接通过运行指令中更换输入实现,而不用去更改代码,而通过argparse,我们就可实现这点。对于开发人员而言,可以通过设定参数及其格式,可以一次从命令行里获取运行需要的所有参数并自动将参数值转为程序设定的格式,非常便于开发。对于用户而言,我只需要更换参数就可以获得我想要的结果,而不需要知道其内部实现,对用户来说也更容易上手。
- parser = argparse.ArgumentParse() 创建解析对象
- parser.add_subparsers() 添加子命令
- parser.add_argument() 添加参数
- parser.set_defaults() 设置默认值
- args = parser.parse_args() 将参数返回给实例args,通过agrs.属性来获得属性的值
在这段代码中,我们首先通过ArgumentParse类创建一个解析器对象parser,然后我们通过add_subparsers()来创建子命令,通过这种方法,可以是我们的程序支持多种功能,这里就支持了两个功能,'train_base'和‘train_joint',用户可以通过命令行的参数来确定使用哪个功能
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='command')
# Training command
p_train = subparsers.add_parser('train_base')
...
# Training command
p_train = subparsers.add_parser('train_joint')
...
接着为不同功能设定参数,通过add_argument()来添加参数,'config'表示参数名称,type表示参数类型,如果有default值,则表示如果该参数不输入,则会默认使用该值。
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='command')
# Training command
p_train = subparsers.add_parser('train_base')
p_train.add_argument('config', type=str)
p_train.add_argument('exper_name', type=str)
p_train.add_argument('--eval', action='store_true')
p_train.add_argument('--debug', action='store_true', default=False,
help='turn on debuging mode')
p_train.set_defaults(func=train_base)
args = parser.parse_args()
args.func(config, output_dir, args)
对于这行命令,经过解析器,其选择train_base功能,config参数值为configs/magicpoint_shapes_pair.yaml,exper_name参数为magicpoint_synth
python train4.py train_base configs/magicpoint_shapes_pair.yaml magicpoint_synth --eval
这里我们就知道这行命令的配置文件是什么了,找到这个路径下的配置文件,文件里记录着配置信息,包括使用到的数据集和采用的模型等信息。数据集用的是SyntheticDataset_gaussian,模型用的是SuperPointNet_gauss2。
data:
# name: 'synthetic_shapes'
dataset: 'SyntheticDataset_gaussian'
primitives: 'all'
truncate: {draw_ellipses: 0.3, draw_stripes: 0.2, gaussian_noise: 0.1}
cache_in_memory: true
suffix: 'v6'
add_augmentation_to_test_set: false # set to true to evaluate with noise
gaussian_label:
enable: false
params:
GaussianBlur: {sigma: 0.2}
preprocessing: ## didn't do this
blur_size: 21
resize: [120, 160]
augmentation:
photometric:
enable: true ## for class to recognize
enable_train: true
enable_val: false
primitives: [
'random_brightness', 'random_contrast', 'additive_speckle_noise',
'additive_gaussian_noise', 'additive_shade', 'motion_blur' ]
params:
random_brightness: {max_abs_change: 75}
random_contrast: {strength_range: [0.3, 1.8]}
additive_gaussian_noise: {stddev_range: [0, 15]}
additive_speckle_noise: {prob_range: [0, 0.0035]}
additive_shade:
transparency_range: [-0.5, 0.8]
kernel_size_range: [50, 100]
motion_blur: {max_kernel_size: 7} # origin 7
homographic:
enable: true
enable_train: true
enable_val: false
params:
translation: true
rotation: true
scaling: true
perspective: true
scaling_amplitude: 0.2
perspective_amplitude_x: 0.2
perspective_amplitude_y: 0.2
patch_ratio: 0.8
max_angle: 1.57 # 3.14
allow_artifacts: true
translation_overflow: 0.05
valid_border_margin: 2
warped_pair:
enable: false # false when training only on detector
params:
translation: true
rotation: true
scaling: true
perspective: true
scaling_amplitude: 0.2
perspective_amplitude_x: 0.2
perspective_amplitude_y: 0.2
patch_ratio: 0.85
max_angle: 1.57
allow_artifacts: true # true
valid_border_margin: 3
front_end_model: 'Train_model_heatmap' # 'Train_model_frontend'
model:
name: 'SuperPointNet_gauss2'
params: {
}
detector_loss:
loss_type: 'softmax'
batch_size: 64 # 64
eval_batch_size: 16
learning_rate: 0.001
kernel_reg: 0.
detection_threshold: 0.001 # 1/65
nms: 4
lambda_loss: 0 # disable descriptor loss
dense_loss:
enable: false
params:
descriptor_dist: 4 # 4, 7.5
lambda_d: 800 # 800
sparse_loss:
enable: true
params:
num_matching_attempts: 1000
num_masked_non_matches_per_match: 100
lamda_d: 1
dist: 'cos'
method: '2d'
other_settings: 'train 2d, gauss 0.5'
retrain: True # set true for new model
reset_iter: True
train_iter: 200000 # 200000
tensorboard_interval: 1000 # 200
save_interval: 2000 # 2000
validation_interval: 1000 # one validation of entire val set every N training steps
validation_size: 10
train_show_interval: 1000 # one show of the current training from to Tensorboard every N training steps
seed: 0
我们接着回到train4.py的主函数,前面讲了通过add_subparsers()来创建子命令,通过add_argument()来添加参数,接着通过set_defaults来设置默认值,设定执行函数为train_base,这个执行函数就是训练过程。所以你只需要看train_base函数是怎么执行的,你就知道他这个模型是如何训练的。
# add parser
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='command')
# Training command
p_train = subparsers.add_parser('train_base')
p_train.add_argument('config', type=str)
p_train.add_argument('exper_name', type=str)
p_train.add_argument('--eval', action='store_true')
p_train.add_argument('--debug', action='store_true', default=False,
help='turn on debuging mode')
p_train.set_defaults(func=train_base)
args = parser.parse_args()
1.3.2 训练过程函数
根据主函数,我们找到了执行函数是train_base,通过ctrl+单击,我们可以直接跳转到该函数位置。这里train_base实际执行的是train_joint函数,下面对train_joint函数进行分析
def train_base(config, output_dir, args):
return train_joint(config, output_dir, args)
pass
对于一个训练过程,其步骤大体可以分为
- 数据集加载与划分
- 网络模型设定与训练
def train_joint(config, output_dir, args):
assert 'train_iter' in config
torch.set_default_tensor_type(torch.FloatTensor)
task = config['data']['dataset']
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info('train on device: %s', device)
with open(os.path.join(output_dir, 'config.yml'), 'w') as f:
yaml.dump(config, f, default_flow_style=False)
writer = SummaryWriter(getWriterPath(task=args.command,
exper_name=args.exper_name, date=True))
save_path = get_save_path(output_dir)
# data loading
# data = dataLoader(config, dataset='syn', warp_input=True)
data = dataLoader(config, dataset=task, warp_input=True)
train_loader, val_loader = data['train_loader'], data['val_loader']
datasize(train_loader, config, tag='train')
datasize(val_loader, config, tag='val')
# init the training agent using config file
# from train_model_frontend import Train_model_frontend
from utils.loader import get_module
train_model_frontend = get_module('', config['front_end_model'])
train_agent = train_model_frontend(config, save_path=save_path, device=device)
# writer from tensorboard
train_agent.writer = writer
# feed the data into the agent
train_agent.train_loader = train_loader
train_agent.val_loader = val_loader
# load model initiates the model and load the pretrained model (if any)
train_agent.loadModel()
train_agent.dataParallel()
try:
# train function takes care of training and evaluation
train_agent.train()
except KeyboardInterrupt:
print ("press ctrl + c, save model!")
train_agent.saveModel()
pass
(1)数据集加载与划分
这里我们可以知道,数据是通过dataLoader()得到,所以如果我想知道数据是如何进行处理的,我就可以通过ctrl+单击来看细节。
def dataLoader(config, dataset='syn', warp_input=False, train=True, val=True):
import torchvision.transforms as transforms
training_params = config.get('training', {})
workers_train = training_params.get('workers_train', 1) # 16
workers_val = training_params.get('workers_val', 1) # 16
logging.info(f"workers_train: {workers_train}, workers_val: {workers_val}")
data_transforms = {
'train': transforms.Compose([
transforms.ToTensor(),
]),
'val': transforms.Compose([
transforms.ToTensor(),
]),
}
# if dataset == 'syn':
# from datasets.SyntheticDataset_gaussian import SyntheticDataset as Dataset
# else:
Dataset = get_module('datasets', dataset)
print(f"dataset: {dataset}")
train_set = Dataset(
transform=data_transforms['train'],
task = 'train',
**config['data'],
)
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=config['model']['batch_size'], shuffle=True,
pin_memory=True,
num_workers=workers_train,
worker_init_fn=worker_init_fn
)
val_set = Dataset(
transform=data_transforms['train'],
task = 'val',
**config['data'],
)
val_loader = torch.utils.data.DataLoader(
val_set, batch_size=config['model']['eval_batch_size'], shuffle=True,
pin_memory=True,
num_workers=workers_val,
worker_init_fn=worker_init_fn
)
# val_set, val_loader = None, None
return {'train_loader': train_loader, 'val_loader': val_loader,
'train_set': train_set, 'val_set': val_set}
首先是通过get()函数得到config中workers_train和workers_val的值,其表示的含义是训练和测试的线程数,并记录到日志中
training_params = config.get('training', {})
workers_train = training_params.get('workers_train', 1) # 16
workers_val = training_params.get('workers_val', 1) # 16
logging.info(f"workers_train: {workers_train}, workers_val: {workers_val}")
接下来是进行数据转换。对于一张彩色数字图片,我们通常会将它表成一个H×W×C的3维矩阵。其中,H表示图片的宽,W表示图片的高,C表示图片的通道数。举个例子,对于一张224×224大小的图I,I[i][j]表示一个像素点,每个像素点的值都是一个C维的向量,对于C=1的黑白图像而言,其值的范围为[0,255],而对于C=3的RGB图像而言其值可能为[23,43,231],每一位的值的范围在[0,255]
transforms.ToTensor()会将PIL和numpy格式的数据从[0,255]范围转换到[0,1] ,具体做法其实就是将原始数据除以255,并将数据的shape从(H x W x C)变为(C x H x W),将数据转为tensor的形式,便于后续的计算
import torchvision.transforms as transforms
data_transforms = {
'train': transforms.Compose([
transforms.ToTensor(),
]),
'val': transforms.Compose([
transforms.ToTensor(),
]),
}
根据数据集的名称获取数据集
Dataset = get_module('datasets', dataset)
print(f"dataset: {dataset}")
train_set = Dataset(
transform=data_transforms['train'],
task = 'train',
**config['data'],
)
val_set = Dataset(
transform=data_transforms['train'],
task = 'val',
**config['data'],
)
DataLoader是数据加载器,由于我们在训练的时候,通常会把数据划分为若干块,一小批一小批的进行训练,所以DataLoader结合了数据集和取样器,每次调用抛出一批数据,直到数据全部抛出之后,如果设定shuffle=True,表示会打乱数据位置,那下一次迭代中得到的DataLoader抛出的批数据与之前不同。比如对于数据x为[1,2,3,4,5,6,7,8,9,10],设置batch_size=5,看到在step0和step1时,抛出的两组批数据batch_x大小都为5,且两组数据交集为空,并集为全集,由于设定shuffle=True,在第二轮迭代里,所得的batch_x不相同。
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=config['model']['batch_size'], shuffle=True,
pin_memory=True,
num_workers=workers_train,
worker_init_fn=worker_init_fn
)
val_loader = torch.utils.data.DataLoader(
val_set, batch_size=config['model']['eval_batch_size'], shuffle=True,
pin_memory=True,
num_workers=workers_val,
worker_init_fn=worker_init_fn
)
通过dataLoader()得到数据后,然后划分为训练集和测试集,datasize通过访问知道它是作者编写的一个生成日志信息的函数。
# data loading
# data = dataLoader(config, dataset='syn', warp_input=True)
data = dataLoader(config, dataset=task, warp_input=True)
train_loader, val_loader = data['train_loader'], data['val_loader']
datasize(train_loader, config, tag='train')
datasize(val_loader, config, tag='val')
(2)网络模型设定与训练
这一步关键就是要找agent的网络架构是什么,以及它如何进行训练的,损失函数是什么。所以关键是去看train_agent是怎么来的。
# init the training agent using config file
# from train_model_frontend import Train_model_frontend
from utils.loader import get_module
train_model_frontend = get_module('', config['front_end_model'])
train_agent = train_model_frontend(config, save_path=save_path, device=device)
# writer from tensorboard
train_agent.writer = writer
# feed the data into the agent
train_agent.train_loader = train_loader
train_agent.val_loader = val_loader
# load model initiates the model and load the pretrained model (if any)
train_agent.loadModel()
train_agent.dataParallel()
try:
# train function takes care of training and evaluation
train_agent.train()
except KeyboardInterrupt:
print ("press ctrl + c, save model!")
train_agent.saveModel()
pass
train_agent是由train_model_frontend而来,而train_model_frontend是通过get_module而来,这个get_module是作者定义的一个方法,通过命名其实可以猜测它是用来获取模块的,我们访问这个函数,其实现逻辑就是获取指定路径下的模块,如果路径为空,那就说明该模块在当前路径下,可以直接通过名字获取。通过代码我们可以知道,它要获取的模块是config['front_end_model'],我们前面已经知道它的config是configs/magicpoint_shapes_pair.yaml,所以可以获得这个模块是Train_model_heatmap
def get_module(path, name):
import importlib
if path == '':
mod = importlib.import_module(name)
else:
mod = importlib.import_module('{}.{}'.format(path, name))
return getattr(mod, name)
1.3.3 网络模型分析
接着我们对Train_model_heatmap进行分析,首先找到Train_model_heatmap.py文件,发现这是一个类,所以train_agent是Train_model_heatmap类的一个实例,所以上述代码就变为一个类的初始化,属性的赋值,以及类的功能调用三个部分,由于我们关注的重点在于网络模型,所以我们先去找网络模型在哪定义,一般会在init函数或者命名包含model的函数中。
train_agent = train_model_frontend(config, save_path=save_path, device=device)
train_agent.writer = writer
train_agent.train_loader = train_loader
train_agent.val_loader = val_loader
train_agent.loadModel()
train_agent.dataParallel()
train_agent.train()
train_agent.saveModel()
当我们进入Train_model_heatmap类时,在init函数中,我们发现没有对网络模型的定义,所以回顾上述代码,推断模型是在loadModel()函数中定义的,然后训练过程在train()函数中定义,但是该类中并没有这两个函数。但为什么运行时并不报错呢,这是因为这个类的声明如下
class Train_model_heatmap(Train_model_frontend)
这表明该类是继承Train_model_frontend类,所以当我们在该类里找不到loadModel()和train()时,要去它的父类,即Train_model_frontend里找该方法。在Train_model_frontend的loadModel()中我们找到model = self.config["model"]["name"],从config文件中我们可以知道网络模型是SuperPointNet_gauss2
对于网络模型,一般会在init函数中定义每一层的结构,确定每一层的输入输出大小
class SuperPointNet_gauss2(torch.nn.Module):
""" Pytorch definition of SuperPoint Network. """
def __init__(self, subpixel_channel=1):
super(SuperPointNet_gauss2, self).__init__()
c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
det_h = 65
self.inc = inconv(1, c1)
self.down1 = down(c1, c2)
self.down2 = down(c2, c3)
self.down3 = down(c3, c4)
self.relu = torch.nn.ReLU(inplace=True)
# self.outc = outconv(64, n_classes)
# Detector Head.
self.convPa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
self.bnPa = nn.BatchNorm2d(c5)
self.convPb = torch.nn.Conv2d(c5, det_h, kernel_size=1, stride=1, padding=0)
self.bnPb = nn.BatchNorm2d(det_h)
# Descriptor Head.
self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
self.bnDa = nn.BatchNorm2d(c5)
self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
self.bnDb = nn.BatchNorm2d(d1)
self.output = None
然后在forward函数写出输入x是经过哪些层得到输出的,结合这两个函数,你就能得到网络模型是什么样的。
def forward(self, x):
""" Forward pass that jointly computes unprocessed point and descriptor
tensors.
Input
x: Image pytorch tensor shaped N x 1 x patch_size x patch_size.
Output
semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8.
desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8.
"""
# Let's stick to this version: first BN, then relu
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
# Detector Head.
cPa = self.relu(self.bnPa(self.convPa(x4)))
semi = self.bnPb(self.convPb(cPa))
# Descriptor Head.
cDa = self.relu(self.bnDa(self.convDa(x4)))
desc = self.bnDb(self.convDb(cDa))
dn = torch.norm(desc, p=2, dim=1) # Compute the norm.
desc = desc.div(torch.unsqueeze(dn, 1)) # Divide by norm to normalize.
output = {'semi': semi, 'desc': desc}
self.output = output
return output
接着我们来看训练过程
def train(self, **options):
"""
# outer loop for training
# control training and validation pace
# stop when reaching max iterations
:param options:
:return:
"""
# training info
logging.info("n_iter: %d", self.n_iter)
logging.info("max_iter: %d", self.max_iter)
running_losses = []
epoch = 0
# Train one epoch
while self.n_iter < self.max_iter:
print("epoch: ", epoch)
epoch += 1
for i, sample_train in tqdm(enumerate(self.train_loader)):
# train one sample
loss_out = self.train_val_sample(sample_train, self.n_iter, True)
self.n_iter += 1
running_losses.append(loss_out)
# run validation
if self._eval and self.n_iter % self.config["validation_interval"] == 0:
logging.info("====== Validating...")
for j, sample_val in enumerate(self.val_loader):
self.train_val_sample(sample_val, self.n_iter + j, False)
if j > self.config.get("validation_size", 3):
break
# save model
if self.n_iter % self.config["save_interval"] == 0:
logging.info(
"save model: every %d interval, current iteration: %d",
self.config["save_interval"],
self.n_iter,
)
self.saveModel()
# ending condition
if self.n_iter > self.max_iter:
# end training
logging.info("End training: %d", self.n_iter)
break
pass
其实主要关注的迭代的过程,我们可以忽略这里的日志信息,得到代码的化简如下,在每一轮迭代中,通过enumerate(self.train_loader)得到训练集的批数据,通过self.train_val_sample()函数得到loss_out,当self._eval and self.n_iter % self.config["validation_interval"] == 0,通过enumerate(self.val_loader)得到测试集的批数据,通过self.train_val_sample()函数训练。由此可以推断出,self.train_val_sample()函数应该是用来参数优化和计算loss的
def train(self, **options):
running_losses = []
epoch = 0
while self.n_iter < self.max_iter:
epoch += 1
for i, sample_train in tqdm(enumerate(self.train_loader)):
loss_out = self.train_val_sample(sample_train, self.n_iter, True)
self.n_iter += 1
running_losses.append(loss_out)
if self._eval and self.n_iter % self.config["validation_interval"] == 0:
for j, sample_val in enumerate(self.val_loader):
self.train_val_sample(sample_val, self.n_iter + j, False)
if j > self.config.get("validation_size", 3):
break
# save model
if self.n_iter % self.config["save_interval"] == 0:
self.saveModel()
# ending condition
if self.n_iter > self.max_iter:
# end training
break
pass
附录
1. GPU(显卡),显卡驱动,CUDA,CUDA Toolkit,cudnn,pytorch之间的关系
- Nvidia Driver驱动:操作系统和硬件GPU进行沟通交互的程序,没这个驱动,GPU就是一个摆设,调用不起来,所以得先装这个。装了这个电脑就能用GPU了,但是深度学习搞不定。
- CUDA Toolkit工具包:是基于驱动程序,用来实现GPU并行计算和加速深度学习的软件包。通过这个软件包,调用驱动,实现更加高级的功能。
- cudnn深度学习加速库:是专门针对深度学习的GPU加速库,如果你要使用深度学习框架,这个是必需的。虽然也是CUDA Toolkit的组件,但是官方的下载包里面并没有,需要自己额外下载。
- pytorch,tensorflow深度学习框架:CUDA Toolkit并没有提供深度学习的框架,只是提供了如何高效调用GPU的软件库。如果你要创建深度学习模型,进行训练。还是要使用深度学习框架。
- Nvidia官方下载的CUDA:官方下载的CUDA Toolkit是包含了驱动的,并且联合到一块是称为CUDA的。
- conda下载的CUDA:是不包含驱动的,但是高版本的驱动,是向前兼容的,你可以下载相匹配的cuda toolkit,构成不同版本的cuda