Dive into Deep Learning (3): Dropout (Gluon)

Copyright notice: the author put a lot of work into compiling these blog posts. Please credit the source when reposting. Thanks! https://blog.csdn.net/Quincuntial/article/details/79572275

Author: Tyan
Blog: noahsnail.com  |  CSDN  |  Jianshu

Note: these are course notes on Mu Li's "Dive into Deep Learning" (《动手学深度学习》).

import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet import autograd
from mxnet.gluon import nn
# Helper functions from the course's utils module (not part of MXNet itself)
from utils import load_data_fashion_mnist, accuracy, evaluate_accuracy

Define the model and add dropout layers

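# Define the model
net = nn.Sequential()
# Dropout probabilities
drop_prob1 = 0.2
drop_prob2 = 0.5

# Add the layers
with net.name_scope():
    # Flatten each input image into a vector
    net.add(nn.Flatten())
    # First fully connected layer
    net.add(nn.Dense(256, activation="relu"))
    # Dropout after the first hidden layer
    net.add(nn.Dropout(drop_prob1))
    # Second fully connected layer
    net.add(nn.Dense(256, activation="relu"))
    # Dropout after the second hidden layer
    net.add(nn.Dropout(drop_prob2))
    # Output layer (10 classes)
    net.add(nn.Dense(10))

# Initialize the model parameters
net.initialize()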
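The post does not spell out what a dropout layer actually computes. Below is a minimal pure-Python sketch, assuming the inverted-dropout formulation (which, as far as I know, is what Gluon's nn.Dropout implements). During training, each element is set to zero with probability drop_prob and every surviving element is divided by 1 - drop_prob, so each output keeps the same expected value as its input. At prediction time the input passes through unchanged. The function dropout and its training and rng parameters are hypothetical names for illustration, not MXNet API.

```python
import random


def dropout(xs, drop_prob, training=True, rng=random):
    """Inverted dropout on a flat list of floats (illustrative sketch, not MXNet's code)."""
    if not 0.0 <= drop_prob < 1.0:
        raise ValueError("drop_prob must be in [0, 1)")
    if not training or drop_prob == 0.0:
        # Prediction mode (or no dropout at all): the layer is the identity
        return list(xs)
    keep_prob = 1.0 - drop_prob
    # Keep each element with probability keep_prob and rescale it by 1/keep_prob;
    # otherwise replace it with 0.0
    return [x / keep_prob if rng.random() < keep_prob else 0.0 for x in xs]


# Usage with the two rates from the model above (hypothetical input values)
rng = random.Random(0)
hidden = [1.0] * 8
print(dropout(hidden, 0.2, rng=rng))          # each entry is 0.0 or 1.0 / (1 - 0.2)
print(dropout(hidden, 0.5, rng=rng))          # each entry is 0.0 or 2.0
print(dropout(hidden, 0.5, training=False))   # unchanged

# The rescaling keeps the mean roughly unchanged over many samples
many = dropout([1.0] * 100_000, 0.5, rng=rng)
print(sum(many) / len(many))                  # close to 1.0
```

In the Gluon code, the switch between the two modes is automatic. Layers run in training mode inside autograd.record(), so the logged training accuracy is computed with dropout active, while evaluate_accuracy is called outside record() and sees the full network. This is one reason training accuracy can sit below test accuracy in the log.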

Load the data and train

# Mini-batch size
batch_size = 256

# Load the Fashion-MNIST data
train_data, test_data = load_data_fashion_mnist(batch_size)

# Optimizer: plain SGD with learning rate 0.5
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})

# Softmax cross-entropy loss
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

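# Train
for epoch in range(5):
    # Accumulated training loss
    train_loss = 0.0
    # Accumulated training accuracy
    train_acc = 0.0
    # Iterate over mini-batches
    for data, label in train_data:
        # Inside record(), Gluon runs in training mode, so dropout is active
        with autograd.record():
            # Forward pass
            output = net(data)
            # Compute the per-example loss
            loss = softmax_cross_entropy(output, label)
        # Backpropagate
        loss.backward()
        # Update the parameters (gradients are normalized by 1/batch_size)
        trainer.step(batch_size)
        # Accumulate this batch's mean training loss
        train_loss += nd.mean(loss).asscalar()
        # Accumulate this batch's training accuracy
        train_acc += accuracy(output, label)
    # Test accuracy (computed outside record(), so dropout is disabled)
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
        epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))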
Epoch 0. Loss: 0.817475, Train acc 0.697349, Test acc 0.778145
Epoch 1. Loss: 0.515098, Train acc 0.810831, Test acc 0.847456
Epoch 2. Loss: 0.458402, Train acc 0.833450, Test acc 0.823918
Epoch 3. Loss: 0.419452, Train acc 0.846554, Test acc 0.862079
Epoch 4. Loss: 0.396483, Train acc 0.854067, Test acc 0.874499
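The update line trainer.step(batch_size) also performs a normalization. Because loss holds one value per example, loss.backward() produces the gradient of the summed loss. According to the Gluon Trainer documentation, step(batch_size) then scales that gradient by 1/batch_size, so the learning rate of 0.5 is applied to the batch-averaged gradient. Here is a pure-Python sketch of that rule for a single scalar parameter. The helper names sgd_step and squared_loss_grad and the toy data are hypothetical, and none of this is MXNet code.

```python
def sgd_step(param, per_example_grads, learning_rate, batch_size):
    """One SGD update mimicking trainer.step(batch_size) for a scalar parameter."""
    # Backpropagating a vector of per-example losses yields the gradient of their sum
    summed_grad = sum(per_example_grads)
    # step(batch_size) rescales by 1/batch_size before applying the learning rate
    return param - learning_rate * summed_grad / batch_size


def squared_loss_grad(w, x, y):
    """Gradient of the per-example loss (w * x - y) ** 2 with respect to w."""
    return 2.0 * (w * x - y) * x


# Hypothetical toy batch generated by y = 2x
batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
w = 0.0
for _ in range(50):
    grads = [squared_loss_grad(w, x, y) for x, y in batch]
    w = sgd_step(w, grads, learning_rate=0.1, batch_size=len(batch))
print(w)  # converges toward 2.0
```

Dividing by the batch size keeps the effective step size independent of how many examples are in a batch. The Gluon documentation suggests passing 1 instead only when the loss has already been averaged manually, e.g. with its mean.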