从图像超分辨率快速入门pytorch

本文旨在帮助Pytorch初学者快速理解深度学习训练过程,以图像超分辨率任务为例,介绍网络模型、数据加载、损失函数和优化器四个关键要素。通过实例代码,展示如何定义网络模型、自定义数据集以及进行训练和优化。
摘要由CSDN通过智能技术生成

前言

最近又开始把pytorch拾起来,学习了github上一些项目之后,发现每个人都会用不同的方式来写深度学习的训练代码,而这些代码对于初学者来说是难以阅读的,因为关键和非关键代码糅杂在一起,让那些需要快速将代码跑起来的初学者摸不着头脑。

所以,本文打算从最基本的出发,只写关键代码,将完成一次深度学习训练需要哪些要素展现给各位初学者,以便你们能够快速上手。等到能够将自己的想法用最简洁的方式写出来并运行起来之后,再对自己的代码进行重构、扩展。我认为这种学习方式是较好的循序渐进的学习方式。

本文选择超分辨率作为入门案例,一是因为通过结合案例能够对训练中涉及到的东西有较好的体会,二是超分辨率是较为简单的任务,我们本次教程的目的是教会大家如何使用pytorch,所以不应该将难度设置在任务本身上。下面开始正文。。。

正文

单一图像超分辨率(SISR)

简单介绍一下图像超分辨率这一任务:超分辨率的任务就是将一张图像的尺寸放大并且要求失真越小越好,举例来说,我们需要将一张256*500的图像放大2倍,那么放大后的图像尺寸就应该是512*1000。用深度学习的方法,我们通常会先将图像缩小成原来的1/2,然后以原始图像作为标签,进行训练。训练的目标是让缩小后的图像放大2倍后与原图越近越好。所以通常会用L1或者L2作为损失函数。

训练4要素

一次训练要想完成,需要的要素我总结为4点:

  • 网络模型
  • 数据
  • 损失函数
  • 优化器

这4个对象都是一次训练必不可少的,通常情况下,需要我们自定义的是前两个:网络模型和数据,而后面两个较为统一,而且pytorch也提供了非常全面的实现供我们使用,它们分别在torch.nn包和torch.optim包下面,使用的时候可以到pytorch官网进行查看,后面我们用到的时候还会再次说明。

网络模型

在网络模型和数据两个当中,网络模型是比较简单的,数据加载稍微麻烦些。我们先来看网络模型的定义。自定义的网络模型都必须继承torch.nn.Module这个类,里面有两个方法需要重写:初始化方法__init__(self)forward(self, *input)方法。在初始化方法中一般要写我们需要哪些层(卷积层、全连接层等),而在forward方法中我们需要写这些层的连接方式。举一个通俗的例子,搭积木需要一个个的积木块,这些积木块放在__init__方法中,而规定将这些积木块如何连接起来则是靠forward方法中的内容。

import torch.nn as nn
import torch.nn.functional as F


class VDSR(nn.Module):

    def __init__(self):
        super(VDSR, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv6 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv7 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv8 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv9 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.
  • 11
    点赞
  • 80
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值