深度学习 — — 入门PyTorch(一)

PyTorch是Facebook开发的AI框架,其最新代码在GitHub进行更新。自2017年以来,它的使用率稳步一直保持稳定增长。相对于TensorFlow框架入门更为简单,也可以很方便的进行网络的构建以完成网络的训练,从而帮助我们很快的复现论文,是一个非常值得学习的框架。

本文主要介绍PyTorch的入门知识,从构建网络模型开始,到如何创建自定义的数据加载器,然后更新网络权重以完成模型的训练。

构建网络

PyTorch提供了一种构建自己模型的标准方法,整个定义应保留在对象中,该对象是nn.Module类的子类。在该类中,一般包含__init__和forward方法,为了更形象的解释如何使用这些基本概念,我在下面给出了一个神经网络模型构建的示例,该网络包含3个全连接层和2个RELU层。

import torch
import torch.nn as nn
import torch.nn.functional as F




class Net(nn.Module):


    def __init__(self):
        super(Net, self).__init__()
        # Defining 3 linear layers but NOT the way they should be connected
        # Receives an array of length 240 and outputs one with length 120
        self.fc1 = nn.Linear(240, 120)
        # Receives an array of length 120 and outputs one with length 60
        self.fc2 = nn.Linear(120, 60)
        # Receives an array of length 60 and outputs one with length 10
        self.fc3 = nn.Linear(60, 10)


    def forward(self, x):
        # Defining the way that the layers of the model should be connected
        # Performs RELU on the output of layer 'self.fc1 = nn.Linear(240, 120)'
        x = F.relu(self.fc1(x))
        # Performs RELU on the output of layer 'self.fc2 = nn.Linear(120, 60)'
        x = F.relu(self.fc2(x))
        # Passes the array through the last linear layer 'self.fc3 = nn.Linear(60, 10)'
        x = self.fc3(x)
        return x




net = Net()

· __init__

与其他Python类一样,__init__方法用于定义类的属性和初始化卷积的一些参数。在PyTorch上下文中,始终调用super()方法来初始化父类。除此之外,还可以定义所有具有可优化参数的网络层,对于网络层的定义不需要按照在网络中使用的顺序,因为此处仅完成对网络层的定义。

· forward

表示网络的前向传播过程,即表示各层如何连接的方法,用来构建网络层的先后运算步骤。从上述示例中可以看到,在其中调用__init__内定义的网络层,然后返回代表网络输出的值。

值得注意的是,在forward方法中应用了一些其他函数,这些函数在__init__方法中未定义,但也可以称作网络层。以F.relu()函数为例,我们没有在__init__方法中定义它,是因为它没有任何可训练的参数。换句话说,如果给F.relu()函数提供相同的输入,它将始终提供相同的输出,网络的训练不会影响其行为。因此,根据经验,可以将没有任何权重更新的网络层放入forward方法中。换而言之,将所有具有权重的网络层放在__init__中。

加载数据:Dataset和DataLoader

Dataset和DataLoader是PyTorch中的两个工具,可以定义如何访问数据,便于读者使用自己的数据完成对模型的训练。下面的代码提供了一个使用简单的Dataset / DataLoader类的示例,说明如何定义自己的Dataset类,并完成数据的加载过程。

import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader


class ExampleDataset(Dataset):
"""Example Dataset"""


def __init__(self, csv_file):
""" 
        csv_file (string): Path to the csv file containing data.
        """
        self.data_frame = pd.read_csv(csv_file)


def __len__(self):
return len(self.data_frame)


def __getitem__(self, idx):
return self.data_frame[idx]


# instantiates the dataset  
example_dataset = ExampleDataset('my_data_file.csv')


# batch size: number of samples returned per iteration
# shuffle: Flag to shuffle the data before reading so you don't read always in the same order
# num_workers: used to load the data in parallel
example_data_loader = DataLoader(example_dataset, , batch_size=4, shuffle=True, num_workers=4)


# Loops over the data 4 samples at a time
for batch_index, batch in enumerate(example_data_loader):
    print(batch_index, batch)

上述Dataset类中使用了3种方法:

· __init__

在初始化过程中,应该输入数据目录信息和其他允许访问的信息。例如上述示例是从csv文件加载数据,也可以使用加载文件名列表,其中每个文件名代表一个数据。注意:在该过程中还未加载数据。

· __len__

该方法用于返回数据集的大小。例如,如果某些目录中有一些图像,则必须实现一种对构成该数据集文件总数进行计数的方法。上述示例中只是获得数据帧的长度。

· __getitem__

该方法用于接收一个索引idx,并返回数据集中对应的数据和标签,是数据加载的核心方法。

为了更有效地加载数据集,我们可以使用DataLoader类。该类可以并行读取一批数据,同时可以选择是否对数据进行重新排序。所有上述操作技巧都可以帮助我们更好地完成模型的训练。

·  END  ·

RECOMMEND

推荐阅读

 1. 效率提升的软件大礼包

 2. 那么多可选编程语言,Why Python?

 3. 学习Python,你选对书了吗?

 4. 90%初学者会混淆的Python概念

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值