Softmax Regression Classification (2)

This post compares a single fully connected layer with a multilayer perceptron on the FashionMNIST dataset. After 100 training epochs, the fully connected layer reaches a test accuracy of about 82%, while a multilayer perceptron with a ReLU activation, trained under the same conditions, climbs to roughly 83%, a modest improvement.
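For reference (my notation, not part of the original post), both models below are trained with nn.CrossEntropyLoss, i.e. the standard softmax regression objective: a softmax over the 10 class logits o, followed by the negative log-likelihood of the true label y:

\hat{y}_j = \mathrm{softmax}(\mathbf{o})_j = \frac{\exp(o_j)}{\sum_{k=1}^{10} \exp(o_k)}, \qquad \ell(\mathbf{o}, y) = -\log \hat{y}_y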
import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append("..")
from utils.SlowFast import SlowFast1
from utils.SlowFast import get_fashion_data

Using a fully connected layer to replace the slowfast implementation

num_inputs = 784
num_outputs = 10
class LinearNet(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)
        init.normal_(self.linear.weight, mean = 0, std = 0.01)
        init.constant_(self.linear.bias, val = 0)
    def forward(self, x): # x shape: (batch, 1, 28, 28)
        y = self.linear(x.view(x.shape[0], -1))
        return y
net = LinearNet(num_inputs, num_outputs)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = 0.01)
slowfast1 = SlowFast1(image_size = 28*28, class_nums = 10)
num_epochs = 100
batch_size = 1000
train_iter, test_iter = get_fashion_data(img_dir='./Datasets/FashionMNIST', batch_size = batch_size)
slowfast1.train(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)
epoch 1, loss 0.0018, train acc 0.570, test acc 0.652
epoch 2, loss 0.0014, train acc 0.666, test acc 0.659
epoch 3, loss 0.0012, train acc 0.673, test acc 0.671
epoch 4, loss 0.0011, train acc 0.686, test acc 0.679
epoch 5, loss 0.0010, train acc 0.699, test acc 0.693
epoch 6, loss 0.0009, train acc 0.712, test acc 0.705
epoch 7, loss 0.0009, train acc 0.724, test acc 0.718
epoch 8, loss 0.0009, train acc 0.733, test acc 0.726
epoch 9, loss 0.0008, train acc 0.742, test acc 0.732
epoch 10, loss 0.0008, train acc 0.748, test acc 0.740
epoch 11, loss 0.0008, train acc 0.754, test acc 0.743
epoch 12, loss 0.0008, train acc 0.760, test acc 0.747
epoch 13, loss 0.0008, train acc 0.764, test acc 0.750
epoch 14, loss 0.0007, train acc 0.767, test acc 0.755
epoch 15, loss 0.0007, train acc 0.771, test acc 0.757
epoch 16, loss 0.0007, train acc 0.774, test acc 0.762
epoch 17, loss 0.0007, train acc 0.777, test acc 0.764
epoch 18, loss 0.0007, train acc 0.779, test acc 0.766
epoch 19, loss 0.0007, train acc 0.782, test acc 0.769
epoch 20, loss 0.0007, train acc 0.785, test acc 0.772
epoch 21, loss 0.0007, train acc 0.787, test acc 0.773
epoch 22, loss 0.0007, train acc 0.789, test acc 0.774
epoch 23, loss 0.0007, train acc 0.791, test acc 0.777
epoch 24, loss 0.0007, train acc 0.792, test acc 0.779
epoch 25, loss 0.0007, train acc 0.794, test acc 0.780
epoch 26, loss 0.0007, train acc 0.795, test acc 0.781
epoch 27, loss 0.0006, train acc 0.796, test acc 0.783
epoch 28, loss 0.0006, train acc 0.798, test acc 0.785
epoch 29, loss 0.0006, train acc 0.799, test acc 0.786
epoch 30, loss 0.0006, train acc 0.800, test acc 0.787
epoch 31, loss 0.0006, train acc 0.801, test acc 0.788
epoch 32, loss 0.0006, train acc 0.802, test acc 0.788
epoch 33, loss 0.0006, train acc 0.803, test acc 0.789
epoch 34, loss 0.0006, train acc 0.804, test acc 0.791
epoch 35, loss 0.0006, train acc 0.805, test acc 0.793
epoch 36, loss 0.0006, train acc 0.806, test acc 0.793
epoch 37, loss 0.0006, train acc 0.806, test acc 0.794
epoch 38, loss 0.0006, train acc 0.807, test acc 0.795
epoch 39, loss 0.0006, train acc 0.808, test acc 0.796
epoch 40, loss 0.0006, train acc 0.809, test acc 0.795
epoch 41, loss 0.0006, train acc 0.810, test acc 0.797
epoch 42, loss 0.0006, train acc 0.810, test acc 0.798
epoch 43, loss 0.0006, train acc 0.811, test acc 0.799
epoch 44, loss 0.0006, train acc 0.811, test acc 0.801
epoch 45, loss 0.0006, train acc 0.812, test acc 0.801
epoch 46, loss 0.0006, train acc 0.813, test acc 0.801
epoch 47, loss 0.0006, train acc 0.813, test acc 0.803
epoch 48, loss 0.0006, train acc 0.814, test acc 0.803
epoch 49, loss 0.0006, train acc 0.815, test acc 0.804
epoch 50, loss 0.0006, train acc 0.815, test acc 0.805
epoch 51, loss 0.0006, train acc 0.815, test acc 0.805
epoch 52, loss 0.0006, train acc 0.816, test acc 0.805
epoch 53, loss 0.0006, train acc 0.817, test acc 0.805
epoch 54, loss 0.0006, train acc 0.817, test acc 0.806
epoch 55, loss 0.0006, train acc 0.818, test acc 0.807
epoch 56, loss 0.0006, train acc 0.818, test acc 0.806
epoch 57, loss 0.0006, train acc 0.819, test acc 0.807
epoch 58, loss 0.0006, train acc 0.819, test acc 0.807
epoch 59, loss 0.0006, train acc 0.820, test acc 0.807
epoch 60, loss 0.0006, train acc 0.820, test acc 0.808
epoch 61, loss 0.0006, train acc 0.820, test acc 0.808
epoch 62, loss 0.0006, train acc 0.821, test acc 0.809
epoch 63, loss 0.0006, train acc 0.821, test acc 0.808
epoch 64, loss 0.0006, train acc 0.821, test acc 0.809
epoch 65, loss 0.0006, train acc 0.822, test acc 0.810
epoch 66, loss 0.0006, train acc 0.822, test acc 0.810
epoch 67, loss 0.0005, train acc 0.822, test acc 0.811
epoch 68, loss 0.0005, train acc 0.823, test acc 0.810
epoch 69, loss 0.0005, train acc 0.823, test acc 0.810
epoch 70, loss 0.0005, train acc 0.824, test acc 0.811
epoch 71, loss 0.0005, train acc 0.824, test acc 0.811
epoch 72, loss 0.0005, train acc 0.824, test acc 0.811
epoch 73, loss 0.0005, train acc 0.824, test acc 0.812
epoch 74, loss 0.0005, train acc 0.825, test acc 0.812
epoch 75, loss 0.0005, train acc 0.825, test acc 0.812
epoch 76, loss 0.0005, train acc 0.826, test acc 0.812
epoch 77, loss 0.0005, train acc 0.826, test acc 0.812
epoch 78, loss 0.0005, train acc 0.826, test acc 0.812
epoch 79, loss 0.0005, train acc 0.826, test acc 0.813
epoch 80, loss 0.0005, train acc 0.827, test acc 0.813
epoch 81, loss 0.0005, train acc 0.827, test acc 0.813
epoch 82, loss 0.0005, train acc 0.827, test acc 0.813
epoch 83, loss 0.0005, train acc 0.827, test acc 0.814
epoch 84, loss 0.0005, train acc 0.827, test acc 0.815
epoch 85, loss 0.0005, train acc 0.827, test acc 0.815
epoch 86, loss 0.0005, train acc 0.828, test acc 0.815
epoch 87, loss 0.0005, train acc 0.828, test acc 0.815
epoch 88, loss 0.0005, train acc 0.828, test acc 0.816
epoch 89, loss 0.0005, train acc 0.829, test acc 0.816
epoch 90, loss 0.0005, train acc 0.828, test acc 0.816
epoch 91, loss 0.0005, train acc 0.829, test acc 0.816
epoch 92, loss 0.0005, train acc 0.829, test acc 0.816
epoch 93, loss 0.0005, train acc 0.829, test acc 0.817
epoch 94, loss 0.0005, train acc 0.829, test acc 0.816
epoch 95, loss 0.0005, train acc 0.829, test acc 0.816
epoch 96, loss 0.0005, train acc 0.830, test acc 0.817
epoch 97, loss 0.0005, train acc 0.830, test acc 0.817
epoch 98, loss 0.0005, train acc 0.830, test acc 0.818
epoch 99, loss 0.0005, train acc 0.830, test acc 0.817
epoch 100, loss 0.0005, train acc 0.830, test acc 0.817

The accuracy matches the previous section, again around 82%.
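The internals of SlowFast1.train are not shown in this post; as a rough illustration only, the per-epoch "test acc" numbers above could come from an evaluation loop like the following (a minimal sketch under my own assumptions, not the actual SlowFast utility):

def evaluate_accuracy(data_iter, net):
    # Fraction of samples whose argmax logit matches the label.
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n

# e.g. evaluate_accuracy(test_iter, net) should land near 0.817
# for the trained linear model above.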

Training with a multilayer perceptron

class LinearMultiNet(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(LinearMultiNet, self).__init__()
        num_hidden = 50
        self.net = nn.Sequential(
            nn.Linear(num_inputs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_outputs),
            )
        # Initialize every weight and bias from N(0, 0.01). Iterate over
        # self.parameters() here: referring to the global `net` would
        # initialize the previously built LinearNet instead of this model.
        for params in self.parameters():
            init.normal_(params, mean=0, std=0.01)
    def forward(self, x): # x shape: (batch, 1, 28, 28)
        y = self.net(x.view(x.shape[0], -1))
        return y
net = LinearMultiNet(num_inputs, num_outputs)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr = 0.01)
slowfast1 = SlowFast1(image_size = 28*28, class_nums = 10)
num_epochs = 100
batch_size = 1000
train_iter, test_iter = get_fashion_data(img_dir='./Datasets/FashionMNIST', batch_size = batch_size)
slowfast1.train(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)
epoch 1, loss 0.0022, train acc 0.216, test acc 0.327
epoch 2, loss 0.0019, train acc 0.466, test acc 0.592
epoch 3, loss 0.0017, train acc 0.620, test acc 0.623
epoch 4, loss 0.0014, train acc 0.640, test acc 0.642
epoch 5, loss 0.0013, train acc 0.656, test acc 0.655
epoch 6, loss 0.0011, train acc 0.665, test acc 0.660
epoch 7, loss 0.0011, train acc 0.671, test acc 0.664
epoch 8, loss 0.0010, train acc 0.679, test acc 0.676
epoch 9, loss 0.0009, train acc 0.688, test acc 0.684
epoch 10, loss 0.0009, train acc 0.696, test acc 0.688
epoch 11, loss 0.0009, train acc 0.704, test acc 0.696
epoch 12, loss 0.0008, train acc 0.711, test acc 0.702
epoch 13, loss 0.0008, train acc 0.717, test acc 0.709
epoch 14, loss 0.0008, train acc 0.724, test acc 0.715
epoch 15, loss 0.0008, train acc 0.732, test acc 0.719
epoch 16, loss 0.0008, train acc 0.737, test acc 0.728
epoch 17, loss 0.0007, train acc 0.742, test acc 0.732
epoch 18, loss 0.0007, train acc 0.746, test acc 0.737
epoch 19, loss 0.0007, train acc 0.751, test acc 0.741
epoch 20, loss 0.0007, train acc 0.756, test acc 0.747
epoch 21, loss 0.0007, train acc 0.759, test acc 0.751
epoch 22, loss 0.0007, train acc 0.763, test acc 0.755
epoch 23, loss 0.0007, train acc 0.767, test acc 0.757
epoch 24, loss 0.0007, train acc 0.770, test acc 0.761
epoch 25, loss 0.0007, train acc 0.774, test acc 0.765
epoch 26, loss 0.0007, train acc 0.777, test acc 0.768
epoch 27, loss 0.0006, train acc 0.780, test acc 0.771
epoch 28, loss 0.0006, train acc 0.783, test acc 0.775
epoch 29, loss 0.0006, train acc 0.786, test acc 0.777
epoch 30, loss 0.0006, train acc 0.789, test acc 0.779
epoch 31, loss 0.0006, train acc 0.792, test acc 0.782
epoch 32, loss 0.0006, train acc 0.794, test acc 0.784
epoch 33, loss 0.0006, train acc 0.796, test acc 0.786
epoch 34, loss 0.0006, train acc 0.798, test acc 0.788
epoch 35, loss 0.0006, train acc 0.800, test acc 0.791
epoch 36, loss 0.0006, train acc 0.802, test acc 0.792
epoch 37, loss 0.0006, train acc 0.804, test acc 0.794
epoch 38, loss 0.0006, train acc 0.806, test acc 0.794
epoch 39, loss 0.0006, train acc 0.807, test acc 0.798
epoch 40, loss 0.0006, train acc 0.808, test acc 0.798
epoch 41, loss 0.0006, train acc 0.810, test acc 0.798
epoch 42, loss 0.0006, train acc 0.812, test acc 0.800
epoch 43, loss 0.0006, train acc 0.813, test acc 0.801
epoch 44, loss 0.0006, train acc 0.813, test acc 0.803
epoch 45, loss 0.0006, train acc 0.815, test acc 0.803
epoch 46, loss 0.0006, train acc 0.815, test acc 0.804
epoch 47, loss 0.0005, train acc 0.816, test acc 0.805
epoch 48, loss 0.0005, train acc 0.817, test acc 0.807
epoch 49, loss 0.0005, train acc 0.817, test acc 0.808
epoch 50, loss 0.0005, train acc 0.819, test acc 0.808
epoch 51, loss 0.0005, train acc 0.820, test acc 0.810
epoch 52, loss 0.0005, train acc 0.821, test acc 0.810
epoch 53, loss 0.0005, train acc 0.822, test acc 0.810
epoch 54, loss 0.0005, train acc 0.822, test acc 0.811
epoch 55, loss 0.0005, train acc 0.823, test acc 0.812
epoch 56, loss 0.0005, train acc 0.824, test acc 0.814
epoch 57, loss 0.0005, train acc 0.825, test acc 0.814
epoch 58, loss 0.0005, train acc 0.826, test acc 0.815
epoch 59, loss 0.0005, train acc 0.826, test acc 0.816
epoch 60, loss 0.0005, train acc 0.826, test acc 0.815
epoch 61, loss 0.0005, train acc 0.827, test acc 0.816
epoch 62, loss 0.0005, train acc 0.828, test acc 0.817
epoch 63, loss 0.0005, train acc 0.828, test acc 0.817
epoch 64, loss 0.0005, train acc 0.829, test acc 0.818
epoch 65, loss 0.0005, train acc 0.830, test acc 0.819
epoch 66, loss 0.0005, train acc 0.830, test acc 0.819
epoch 67, loss 0.0005, train acc 0.830, test acc 0.820
epoch 68, loss 0.0005, train acc 0.831, test acc 0.820
epoch 69, loss 0.0005, train acc 0.831, test acc 0.820
epoch 70, loss 0.0005, train acc 0.832, test acc 0.820
epoch 71, loss 0.0005, train acc 0.832, test acc 0.821
epoch 72, loss 0.0005, train acc 0.832, test acc 0.820
epoch 73, loss 0.0005, train acc 0.832, test acc 0.822
epoch 74, loss 0.0005, train acc 0.833, test acc 0.822
epoch 75, loss 0.0005, train acc 0.834, test acc 0.822
epoch 76, loss 0.0005, train acc 0.834, test acc 0.821
epoch 77, loss 0.0005, train acc 0.834, test acc 0.822
epoch 78, loss 0.0005, train acc 0.835, test acc 0.822
epoch 79, loss 0.0005, train acc 0.835, test acc 0.823
epoch 80, loss 0.0005, train acc 0.836, test acc 0.823
epoch 81, loss 0.0005, train acc 0.836, test acc 0.822
epoch 82, loss 0.0005, train acc 0.836, test acc 0.823
epoch 83, loss 0.0005, train acc 0.837, test acc 0.824
epoch 84, loss 0.0005, train acc 0.837, test acc 0.824
epoch 85, loss 0.0005, train acc 0.837, test acc 0.824
epoch 86, loss 0.0005, train acc 0.838, test acc 0.824
epoch 87, loss 0.0005, train acc 0.838, test acc 0.825
epoch 88, loss 0.0005, train acc 0.838, test acc 0.825
epoch 89, loss 0.0005, train acc 0.839, test acc 0.825
epoch 90, loss 0.0005, train acc 0.839, test acc 0.825
epoch 91, loss 0.0005, train acc 0.840, test acc 0.826
epoch 92, loss 0.0005, train acc 0.840, test acc 0.827
epoch 93, loss 0.0005, train acc 0.840, test acc 0.827
epoch 94, loss 0.0005, train acc 0.840, test acc 0.827
epoch 95, loss 0.0005, train acc 0.841, test acc 0.828
epoch 96, loss 0.0005, train acc 0.841, test acc 0.828
epoch 97, loss 0.0005, train acc 0.841, test acc 0.828
epoch 98, loss 0.0005, train acc 0.841, test acc 0.829
epoch 99, loss 0.0005, train acc 0.842, test acc 0.828
epoch 100, loss 0.0005, train acc 0.841, test acc 0.829

Still not much of an improvement: the multilayer perceptron ends at about 82.9% test accuracy, versus 81.7% for the plain fully connected layer.
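To put that gain in perspective, a quick back-of-the-envelope check (mine, not from the original post): the MLP carries roughly five times the parameters of the single linear layer for about one extra point of test accuracy.

def count_params(model):
    # Total number of trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_params(LinearNet(784, 10)))       # 784*10 + 10 = 7850
print(count_params(LinearMultiNet(784, 10)))  # 784*50 + 50 + 50*10 + 10 = 39760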
