神经网络模型中class的forward函数何时调用_总结深度学习PyTorch神经网络箱使用...

↑ 点击 蓝字  关注极市平台 ebbf4c5964f3151247573d7f7e3bfc1d.gif 来源丨计算机视觉联盟 编辑丨极市平台

极市导读

 

本文介绍了Pytorch神经网络箱的使用,包括核心组件、神经网络实例、构建方法、优化器比较等内容,非常全面。>>加入极市CV技术交流群,走在计算机视觉的最前沿

1 神经网络核心组件

核心组件包括:
  1. 层:神经网络的基本结构,将输入张量转换为输出张量

  2. 模型:层构成的网络

  3. 损失函数:参数学习的目标函数,通过最小化损失函数来学习各种参数

  4. 优化器:如何是损失函数最小

多个层链接一起构成模型或者网络,输入数据通过模型产生预测值,预测值和真实值进行比较得到损失只,优化器利用损失值进行更新权重参数,使得损失值越来越小,循环过程,当损失值达到阈值活着的循环次数叨叨指定次数就结束循环。

2 神经网络实例

如果初学者,建议直接看3,避免运行结果有误。

神经网络工具及相互关系

67fe4df131a00b14d67eb652d6150f1e.png

2.1 背景说明

如何利用神经网络完成对手些数字进行识别? 使用Pytorch内置函数mnist下载数据 利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器 可视化源数据 利用nn工具箱构建神经网络模型 实例化模型,定义损失函数及优化器 训练模型 可视化结果 使用2个隐藏层,每层激活函数为ReLU,最后使用torch.max(out,1)找出张量out最大值对索引作为预测值

c86843be550beb6b817cae2aabc74b17.png

2.2 准备数据

##(1)导入必要的模块import numpy as npimport torch# 导入内置的 mnist数据from torchvision.datasets import mnist# 导入预处理模块import torchvision.transforms as transformsfrom torch.utils.data import DataLoader# 导入nn及优化器import torch.nn.functional as Fimport torch.optim as optimfrom torch import nn## (2) 定义一些超参数
train_batch_size = 64
test_batch_size = 128
learning_rate = 0.01
num_epoches = 20
lr = 0.01
momentum = 0.5## (3) 下载数据并对数据进行预处理# 定义预处理函数,这些预处理依次放在Compose函数中
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])# 下载数据,并对数据进行预处理
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform)# dataloader是一个可迭代的对象,可以使用迭代器一样使用
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
        154f6a8d4fff331a6f7b0fda101417e7.png

2.3 可视化数据源

import matplotlib.pyplot as plt
%matplotlib inline
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
fig = plt.figure()for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[i]))
plt.xticks([])
plt.yticks([])
         f89d0edab19011eb64ec19ac961e0642.png

2.4 构建模型

## (1)构建网络
class Net(nn.Module):
"""
使用sequential构建网络,Sequential()函数功能是将网络的层组合一起
"""
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Net, self).__init__()
self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1))
self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d(n_hidden_2))
self.layer3 = nn.Sequential(nn.L
  • 3
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值