PyTorch Study Notes 9: Implementing a Fully Connected Network (Multilayer Perceptron) in PyTorch

This note implements a fully connected network (a multilayer perceptron) in PyTorch and trains it on Fashion-MNIST.

The network has three layers (counting the input layer):
Input layer: 784 feature units (28×28 pixels);
Hidden layer: 256 units;
Output layer: 10 units (e.g. a 10-class softmax classification).
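
As a quick sanity check of these shapes (an illustrative sketch, not part of the original code), a dummy batch can be pushed through the same architecture:

import torch
from torch import nn

# Hypothetical shape check: (batch, 1, 28, 28) -> flatten to 784 -> 256 -> 10.
# nn.Flatten is the built-in equivalent of the FlattenLayer defined below.
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
x = torch.randn(32, 1, 28, 28)  # dummy batch of 32 Fashion-MNIST-sized images
print(net(x).shape)             # expected: torch.Size([32, 10])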

Code:
import torch
from torch import nn
from torch.nn import init
import torchvision
import torchvision.transforms as transforms
import sys
import time


class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)


def load_data_fashion_mnist(batch_size, root='Datasets/FashionMNIST'):
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=False, transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=False, transform=transforms.ToTensor())
    if sys.platform.startswith('win'):
        num_workers = 0  # 0 means no extra worker processes are used to load the data
    else:
        num_workers = 4
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_iter, test_iter


# Evaluate classification accuracy over a data iterator
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n

# Manual mini-batch SGD update (kept for reference; the script below uses torch.optim.SGD instead)
def sgd(params, lr, batch_size):
    for param in params:
        param.data -= lr * param.grad / batch_size  # update param.data so the change is not tracked by autograd

# Training loop
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
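            # Note: nn.CrossEntropyLoss uses reduction='mean' by default, so the value
            # below is the average loss over the batch and .sum() on that scalar is a
            # no-op; dividing train_l_sum by n later therefore prints a small per-sample
            # figure (roughly the mean batch loss divided by the batch size).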
            l = loss(y_hat, y).sum()

            # Zero the gradients
            optimizer.zero_grad()

            l.backward()
            optimizer.step()

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))


if __name__ == '__main__':
    num_inputs, num_outputs, num_hiddens = 784, 10, 256

    # Define the network: input layer (784) --> hidden layer (256) --> output layer (10)
    net = nn.Sequential(
            FlattenLayer(),
            nn.Linear(num_inputs, num_hiddens),
            nn.ReLU(),
            nn.Linear(num_hiddens, num_outputs),
            )

    # print('len: ', len(list(net.parameters())))
    # print('param: ', net.parameters())
    # print('param list: ', list(net.parameters()))

    # Initialize the network parameters
    for params in net.parameters():  # net.parameters() is an iterable over every layer's parameters
        init.normal_(params, mean=0, std=0.01)

    print('len: ', len(list(net.parameters())))
    print('param init: ', net.parameters())  # <generator object Module.parameters at 0x000001A0B1356620>
    print('param list init: ', list(net.parameters()))

    # Load the data
    batch_size = 256
    train_iter, test_iter = load_data_fashion_mnist(batch_size)

    loss = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)

    # Train
    num_epochs = 5
    train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

A few notes on the code:

net.parameters()
After the network is defined, net.parameters() returns a generator over the parameters of each layer (e.g. <generator object Module.parameters at 0x000001A0B1356620>).
It can be converted to a list to inspect the parameters: list(net.parameters()).
torch.nn.init.normal_(params, mean=0, std=0.01)
Initializes the parameter params from a normal distribution with mean 0 and standard deviation 0.01.
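
As an aside (a minimal sketch, not from the original post), torch.nn.init can also be applied per parameter type via named_parameters(), for example Gaussian weights and zero biases:

from torch import nn
from torch.nn import init

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
for name, param in net.named_parameters():      # yields (name, Parameter) pairs
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)   # Gaussian init for weight matrices
    elif 'bias' in name:
        init.constant_(param, val=0)            # zero init for bias vectors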

Output:
len:  4
param init:  <generator object Module.parameters at 0x000001A0B1356620>
param list init:  [Parameter containing:
tensor([[ 0.0096,  0.0091,  0.0036,  ..., -0.0016, -0.0025,  0.0082],
        [ 0.0014, -0.0060, -0.0170,  ..., -0.0087, -0.0124,  0.0029],
        [-0.0002, -0.0053, -0.0205,  ...,  0.0140,  0.0015, -0.0047],
        ...,
        [-0.0067, -0.0192, -0.0026,  ...,  0.0223,  0.0114,  0.0003],
        [ 0.0033,  0.0095,  0.0097,  ...,  0.0015,  0.0030, -0.0138],
        [ 0.0139, -0.0073,  0.0012,  ..., -0.0143,  0.0085,  0.0056]],
       requires_grad=True), Parameter containing:
tensor([-3.3042e-03,  4.7892e-04, -5.6590e-03,  3.2377e-03,  3.5846e-03,
         6.6989e-03,  2.3601e-03, -9.1927e-04,  2.4281e-02, -2.2155e-02,
         5.3163e-03,  3.7543e-03,  4.3089e-03,  1.1811e-02, -4.1673e-03,
        -1.9667e-02,  2.6118e-03, -3.2978e-03,  1.3942e-02, -1.6289e-02,
         9.8179e-03, -2.2531e-02, -1.4156e-02, -1.4382e-03,  1.7384e-02,
        -1.2549e-02,  9.4562e-03, -9.0459e-03,  8.4983e-03, -1.5124e-03,
        -1.4963e-02, -7.0390e-03,  1.0951e-02, -1.6487e-02, -5.2332e-03,
        -5.2680e-03, -1.7785e-03, -1.3423e-03,  1.9302e-03, -4.9111e-03,
         1.7328e-03, -8.0625e-03, -3.9449e-03,  3.6381e-03,  1.1906e-02,
         5.0710e-03,  5.1031e-03,  9.2445e-04,  2.6244e-02, -2.9451e-03,
         9.6235e-03, -2.1532e-03, -1.3756e-02, -2.1489e-03, -1.3318e-02,
         4.8365e-03, -1.0427e-02,  5.2636e-03,  8.1710e-03, -2.8734e-03,
        -4.0999e-03, -3.3395e-03,  9.2141e-03,  1.8420e-02,  2.3903e-03,
         6.3389e-03, -7.1875e-03,  9.3982e-03, -1.6983e-02, -1.9021e-03,
        -6.3871e-03,  6.7952e-03, -1.2235e-02, -1.6785e-02, -6.6447e-03,
         1.2196e-02,  7.3601e-03, -1.5027e-02, -2.6593e-03, -9.6182e-03,
        -8.4485e-03,  2.2411e-02, -7.5373e-03,  3.6415e-02,  2.6785e-03,
         1.9647e-02, -1.4472e-03, -2.1426e-03, -1.0003e-02, -6.0945e-03,
         6.1464e-04,  6.1757e-03,  1.2456e-02,  1.0664e-02,  8.7811e-03,
        -1.9107e-02, -8.5125e-03, -3.2865e-04,  1.0192e-02, -2.4412e-02,
        -2.1226e-02,  1.0242e-02,  4.0445e-03, -3.3238e-03,  4.4551e-04,
         1.7880e-02,  1.4732e-02,  7.4244e-04,  1.5565e-02,  6.3838e-03,
         4.2519e-03,  3.7454e-04,  6.0372e-03,  1.0598e-02,  6.6352e-03,
         9.3732e-03,  7.1993e-03, -8.0230e-03, -2.0376e-02,  1.7323e-03,
         1.5667e-02, -1.0637e-02, -1.9101e-02, -8.6477e-03,  4.6590e-03,
        -4.7290e-03,  1.2458e-02,  1.0215e-02,  1.4719e-02, -3.4490e-03,
        -4.6496e-03,  6.5331e-03, -3.9560e-03, -1.1488e-02, -8.5887e-03,
         1.5083e-02,  1.0957e-02,  1.9015e-02, -2.1299e-03, -8.0287e-03,
        -1.4993e-02, -1.1674e-02,  7.0364e-03, -2.5001e-03, -1.0356e-03,
         5.7498e-03,  5.7233e-04,  7.9161e-04, -6.0469e-03, -2.6913e-03,
         6.7641e-03,  1.8129e-03,  1.5494e-03, -9.7351e-03,  6.8967e-05,
         2.2971e-03, -9.1847e-03, -2.3717e-03, -6.4801e-03,  2.9549e-03,
        -7.2387e-03, -1.6071e-02, -1.1841e-02, -4.3262e-03, -7.4287e-04,
        -1.0381e-02, -1.9941e-02,  1.2515e-02,  1.1387e-02, -3.3133e-03,
         1.3639e-02, -1.9078e-03, -1.5026e-02,  3.7264e-03,  1.2014e-02,
        -8.0367e-03, -3.5969e-02,  6.3780e-03,  3.4895e-03,  1.5735e-02,
        -5.6254e-04, -5.5807e-03,  5.4600e-04, -8.7495e-04,  7.8439e-03,
        -1.2823e-02, -1.4356e-02,  7.8702e-03,  4.3848e-04,  5.3145e-03,
        -6.1489e-03,  8.7027e-04, -1.0802e-03,  7.2241e-03,  5.0439e-03,
         1.3031e-02,  7.4891e-03, -7.3666e-03, -6.0929e-03, -6.1948e-03,
         8.1562e-03, -6.0273e-03, -1.0222e-02, -1.7376e-03, -1.2922e-02,
         1.1247e-02, -1.0559e-02, -1.5887e-02,  1.0038e-02, -1.4515e-02,
        -9.5886e-03,  1.2830e-02,  8.8126e-03, -9.1111e-03,  6.2043e-03,
         1.9829e-02,  1.5241e-02,  2.2486e-03,  9.0140e-03,  1.7259e-02,
        -5.6758e-03,  4.1752e-03,  4.8623e-04,  1.9457e-02,  8.3239e-03,
        -1.1590e-02, -5.5052e-03, -2.0561e-02,  2.8499e-03,  1.1046e-02,
        -7.4051e-03,  1.1231e-02,  1.4840e-02,  4.9973e-03,  1.3801e-02,
        -1.4826e-02, -7.4246e-03, -1.5146e-02,  1.2617e-02,  7.5188e-03,
         1.9418e-02, -1.0118e-03, -8.8281e-03, -5.6416e-03,  1.8890e-04,
        -3.9850e-03, -4.7776e-03,  9.0903e-03, -3.2510e-02,  3.5589e-03,
         3.9693e-03,  1.9995e-02,  2.7695e-03,  9.5730e-03, -9.2412e-03,
         1.0012e-02], requires_grad=True), Parameter containing:
tensor([[-0.0074, -0.0100,  0.0084,  ...,  0.0004, -0.0123, -0.0015],
        [-0.0017,  0.0017,  0.0104,  ..., -0.0067, -0.0016, -0.0096],
        [-0.0132, -0.0034,  0.0193,  ...,  0.0191,  0.0004,  0.0105],
        ...,
        [ 0.0167, -0.0144,  0.0048,  ...,  0.0061, -0.0083,  0.0072],
        [ 0.0018,  0.0048,  0.0050,  ...,  0.0015, -0.0165,  0.0046],
        [-0.0039,  0.0027, -0.0014,  ...,  0.0078, -0.0054, -0.0089]],
       requires_grad=True), Parameter containing:
tensor([-0.0058, -0.0037,  0.0099, -0.0099,  0.0018, -0.0193, -0.0041,  0.0043,
        -0.0114, -0.0049], requires_grad=True)]
epoch 1, loss 0.0031, train acc 0.699, test acc 0.763
epoch 2, loss 0.0019, train acc 0.817, test acc 0.785
epoch 3, loss 0.0017, train acc 0.844, test acc 0.844
epoch 4, loss 0.0015, train acc 0.856, test acc 0.798
epoch 5, loss 0.0014, train acc 0.865, test acc 0.845

Process finished with exit code 0
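
After training finishes, the model can be used for prediction. Here is a minimal inference sketch, assuming net and test_iter from the script above are still in scope:

# Inference sketch (assumes `net` and `test_iter` from the script above)
X, y = next(iter(test_iter))          # one batch of test images and labels
with torch.no_grad():                 # gradients are not needed for inference
    preds = net(X).argmax(dim=1)      # predicted class indices
print('predicted:', preds[:10].tolist())
print('actual:   ', y[:10].tolist())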

References
https://pytorch.org/docs/stable/nn.html#
https://pytorch.org/docs/stable/nn.init.html (torch.nn.init)
https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch
