PyTorch (2): A simple fully-connected network on the MNIST dataset

This post builds a simple neural network in PyTorch and trains it on the MNIST dataset. It first defines the data-preprocessing pipeline, then loads the data and splits it into training, validation, and test sets. Next it defines the network architecture and trains it with an SGD optimizer, reporting the training and validation loss after each epoch. Finally, it reports the model's accuracy on the test set.
# coding=utf-8
import numpy as np
import torch
from torchvision import transforms

_task = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5], [0.5]
    )
])

from torchvision.datasets import MNIST

# load the MNIST dataset
mnist = MNIST('./data', download=True, train=True, transform=_task)

# train/validation/test split
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# create train/validation/test split (80% / 10% / 10%)
index_list = list(range(len(mnist)))

split_train = int(0.8 * len(mnist))
split_valid = int(0.9 * len(mnist))

train_idx = index_list[:split_train]
valid_idx = index_list[split_train:split_valid]
test_idx = index_list[split_valid:]

# create sampler objects using SubsetRandomSampler
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)

# create iterator objects for train and valid dataset
trainloader = DataLoader(mnist, batch_size=256, sampler=train_sampler)
validloader = DataLoader(mnist, batch_size=256, sampler=valid_sampler)
test_loader = DataLoader(mnist, batch_size=256, sampler=test_sampler)
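A `SubsetRandomSampler` draws (without replacement) only from the indices it is given, so the three loaders above see disjoint parts of the same underlying dataset. A minimal standalone sketch of the same pattern on a toy dataset (the sizes and names here are illustrative, not from the post):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Toy stand-in for MNIST: 100 samples of shape (1, 28, 28)
dataset = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

idx = list(range(len(dataset)))
train_idx, valid_idx = idx[:80], idx[80:90]  # disjoint index lists

loader = DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(train_idx))

x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```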

# network definition
import torch.nn.functional as F


class NetModel(torch.nn.Module):
    def __init__(self):
        super(NetModel, self).__init__()
        self.hidden = torch.nn.Linear(28 * 28, 300)
        self.output = torch.nn.Linear(300, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.hidden(x)
        x = F.relu(x)
        x = self.output(x)
        return x
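The `view(-1, 28 * 28)` call flattens each 1×28×28 image into a 784-dimensional vector before the two linear layers. A quick shape check of the forward pass (a standalone sketch that re-declares the same architecture):

```python
import torch
import torch.nn.functional as F


class NetModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(28 * 28, 300)
        self.output = torch.nn.Linear(300, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)      # flatten (N, 1, 28, 28) -> (N, 784)
        x = F.relu(self.hidden(x))
        return self.output(x)        # raw logits; CrossEntropyLoss applies softmax internally


net = NetModel()
x = torch.randn(4, 1, 28, 28)        # a fake batch of 4 images
print(net(x).shape)                  # torch.Size([4, 10])
```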


if __name__ == "__main__":
    net = NetModel()

    from torch import optim

    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-6, momentum=0.9, nesterov=True)

    for epoch in range(1, 12):
        train_loss, valid_loss = [], []
        net.train()
        for data, target in trainloader:
            optimizer.zero_grad()
            # forward propagation
            output = net(data)
            loss = loss_function(output, target)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        net.eval()
        with torch.no_grad():  # no gradients needed for validation
            for data, target in validloader:
                output = net(data)
                loss = loss_function(output, target)
                valid_loss.append(loss.item())
        print("Epoch:", epoch, "Training Loss:", np.mean(train_loss), "Valid Loss:", np.mean(valid_loss))

    print("testing ... ")
    total = 0
    correct = 0
    net.eval()
    with torch.no_grad():
        for data, label in test_loader:
            output = net(data)
            _, predict = torch.max(output, 1)
            total += label.size(0)
            correct += (predict == label).sum().item()
    print("Accuracy:", (correct / total) * 100, "%")
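The post stops after evaluation, but a common next step is persisting the trained weights. A minimal sketch using `state_dict` (a toy `Linear` module stands in for `NetModel`; an in-memory buffer stands in for a file path):

```python
import io
import torch

net = torch.nn.Linear(28 * 28, 10)  # toy stand-in for the trained NetModel

# Serialize only the parameters; in practice pass a file path
# (e.g. "net.pt") to torch.save / torch.load instead of a buffer.
buf = io.BytesIO()
torch.save(net.state_dict(), buf)
buf.seek(0)

restored = torch.nn.Linear(28 * 28, 10)
restored.load_state_dict(torch.load(buf))

# The restored copy produces identical outputs.
x = torch.randn(2, 28 * 28)
print(torch.equal(net(x), restored(x)))  # True
```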
