pytorch深度学习代码阅读(0):总论

转载请注明出处
个人博客:https://maxusun.github.io/

这几天进入实验室后,开始阅读实验室学姐的一份代码。刚到手的时候像一个烫手的山芋,不知道从哪里下手,而将整个项目的代码阅读完之后,发现在这个过程中能学到很多东西。于是便想将整个过程记录下来,便于学习。

在阅读前不能傻乎乎的埋头就开始一行一行读。首先应该先对代码进行分块。粗略的说,可以分成四块:

  • 数据加载和预处理
  • 神经网络实现
  • 训练网络
  • 其他(工具,配置文件,过程记录等)

数据加载和预处理

训练神经网络最重要的就是数据。数据预处理包括下面几点:

  1. 实现一个Dataset加载自己的数据
    为了使用DataLoader加载数据。我们需要自己实现一个Dataset对数据处理。对数据的一些预处理,如:裁剪、拉伸、旋转、镜像等处理,都是在这个模块中实现。在看这个模块的时候主要看下面几个函数:

    • __init__(self,prams):初始化函数,prams是用户自己传入的参数,像文件路径啊,裁剪时图片的大小呀之类的。主要是对用到的变量初始化。
    • __getitem__(self, index)DataLoader通过调用Dataset的这个函数来读取数据,其中index是指调用哪个数据。一般在这个函数调用处理数据的函数,返回值是datalabel
    • __len__(self):返回数据集中一共有多少个数据。
  2. 对数据的预处理
    在一些项目中,并不是直接把数据(图片)送到神经网络中训练,通常要对其处理,像裁剪、拉伸、上采样、下采样、随机旋转图片、随机旋转RGB通道、随机镜像等。这些方法一般放在Dataset中实现,或者实现一个data_util专门处理这些。

神经网络实现

神经网络这一块没有什么好说的,一般都是根据论文给的结构实现的。但是一定要多看看一些经典的或者复杂的神经网络是怎样实现的,比如:Resnet,Unet,Fast-RNN等这些不同领域典型网络的实现。

训练网络

训练网络部分的代码也比较复杂,因为这里面涉及到配置加载,断点保存等。看这部分代码的时候,首先找几行重要的代码。在这几行代码之间穿插着作者实现的其他代码。首先我们先找到train函数,如果代码里面有evaluatepredict函数,都是差不多的。

def train(param_set):
    model = OurNet(parms)                               # 1. 初始化模型
    # 这中间的代码一般都是配置文件
    criterion = nn.CrossEntropyLoss()                   # 2. 定义损失函数和优化器
    optimizer = Adam(model.parameters(),lr=5e-4)
    # 这中间如果有代码也是配置文件之类的
    loader = DataLoader(xxxxxxx)                        # 3. 使用DataLoader加载数据
    # 这里有代码也是配置文件或者打印信息之类的
    for epoch in range(num_epochs):                     # 4. 每个epoch算是训练数据一次
        for step,(data1,data2,) in enumerate(loader):  # 5. 每一步都加载一个batch的数据
            # 可能会有是否使用GPU的代码
            outputs = model(data1,data2,)              # 6. 得到输出
            # 这中间的代码一般都是计算、打印、保存 loss、acc、dice等
            # 或者保存当前训练的状态
            optimizer.zero_grad()                       # 7. 梯度归零、反向传播、优化参数等
            loss.backward()
            optimizer.step()
            # 保存model的一些代码
        # 每个epoch结束后计算平均loss、acc、dice等指标并保存

上述代码中,白色代码一般是固定的写法,位置也是类似,可以通过这些白色代码定位,从而对整个代码有全局的理解。

其他

剩下的代码都可以归结到其他中,一般都是工具类、配置文件、过程记录、日志打印等。

  1. 过程记录
    过程记录部分,目前见的比较多的是tensorboardX,使用tensorboardX来记录loss等,或者对网络结构进行可视化
  2. 配置文件
    配置文件方面,目前见到的工具有:fireconfigparser这两个库。
  3. 日志打印
    日志打印一般使用重定向流,讲数据输出到terminal或者log文件中。

总体来说,当拿到一份代码后,首先对其有总体的把握,阅读起来就变得容易许多。同时可以根据自己的需求只阅读不同部分的代码。比如只关注数据的预处理,那就看Dataset相关的部分代码;只关注网络结构,那就看network部分的代码;只是想将代码运行起来,那就只看配置文件和训练网络部分的代码。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MaXuwl

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值