PyTorch代码实战(MNIST-FASHION数据集)

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset


import torchvision
import torchvision.transforms as tranfsorms

import matplotlib.pyplot as plt
import numpy
# Training hyperparameters.
lr = 0.15      # learning rate for SGD
gamma = 0      # SGD momentum coefficient (0 = plain gradient descent)
epochs = 10    # number of full passes over the dataset
bs = 128       # mini-batch size
# Instantiate the FashionMNIST dataset (expects files already on disk,
# since download=False).
# FIX: the root path must be a raw string — in a plain string literal,
# backslash sequences like "\j" are invalid escapes (SyntaxWarning in
# Python >= 3.12, slated to become an error).
mnist = torchvision.datasets.FashionMNIST(root = r"D:\jupyterDemo\MINST-FASHION数据集"
                                          ,download = False
                                          ,train = True
                                          ,transform = tranfsorms.ToTensor()
                                         )
# Wrap the dataset in a DataLoader that yields shuffled mini-batches of size bs.
batchdata = DataLoader(mnist
                       ,batch_size = bs
                       ,shuffle = True
)

# Pull a single batch from the loader to inspect the tensor shapes it yields.
x, y = next(iter(batchdata))
print(x.shape)
print(y.shape)
# Model input width: number of elements in one image tensor
# (28*28 = 784 for FashionMNIST).
input_ = mnist.data[0].numel()
# Model output width: number of distinct class labels (10 for FashionMNIST).
output_ = len(mnist.targets.unique())
# 定义神经网络架构
class Model(nn.Module):
    """Two-layer fully connected classifier.

    Architecture: flatten -> Linear(in_features, 128, no bias) -> ReLU
    -> Linear(128, out_features, no bias) -> log-softmax.

    The log-softmax output is meant to be paired with ``nn.NLLLoss``.

    Parameters
    ----------
    in_features : int
        Number of input features per sample (e.g. 28*28 = 784 for FashionMNIST).
    out_features : int
        Number of output classes.
    """

    def __init__(self, in_features=10, out_features=2):
        super().__init__()
        # Remember the input width so forward() can flatten correctly.
        # BUG FIX: the original forward() hard-coded x.view(-1, 28*28),
        # which silently broke any model built with a different in_features.
        self.in_features = in_features
        self.linear1 = nn.Linear(in_features, 128, bias=False)
        self.output = nn.Linear(128, out_features, bias=False)

    def forward(self, x):
        # -1 is a placeholder: PyTorch infers the batch dimension.
        x = x.view(-1, self.in_features)
        sigma1 = torch.relu(self.linear1(x))
        # Log-probabilities over classes, one row per sample.
        sigma2 = F.log_softmax(self.output(sigma1), dim=1)
        return sigma2
# 定义训练函数
def fit(net, batchdata, lr=0.15, epochs=5, gamma=0):
    """Train ``net`` on ``batchdata`` with momentum SGD and NLLLoss.

    Prints running loss and cumulative accuracy every 125 batches and at
    the end of each epoch. Accuracy is cumulative over ALL batches seen
    so far (across epochs), matching the progress percentage shown.

    Parameters
    ----------
    net : nn.Module
        Model whose forward pass returns log-probabilities (log-softmax).
    batchdata : DataLoader
        Yields (x, y) mini-batches.
    lr : float
        Learning rate.
    epochs : int
        Number of passes over the full dataset.
    gamma : float
        SGD momentum coefficient.
    """
    criterion = nn.NLLLoss()                     # expects log-probabilities
    opt = optim.SGD(net.parameters()             # momentum SGD optimizer
                    , lr=lr
                    , momentum=gamma)

    correct = 0   # cumulative correct predictions (python int, not tensor)
    samples = 0   # cumulative samples seen

    for epoch in range(epochs):
        for batch_idx, (x, y) in enumerate(batchdata):
            y = y.view(x.shape[0])               # ensure a flat label vector
            opt.zero_grad()                      # clear stale gradients before backward
            sigma = net(x)                       # use __call__, not net.forward(), so hooks run
            loss = criterion(sigma, y)
            loss.backward()
            opt.step()

            # Running accuracy bookkeeping.
            yhat = torch.max(sigma, 1)[1]        # predicted class index per sample
            # .item() keeps `correct` a plain int instead of accumulating a tensor.
            correct += (yhat == y).sum().item()
            samples += x.shape[0]
            if (batch_idx + 1) % 125 == 0 or batch_idx == len(batchdata) - 1:
                print("Epoch{}:[{}/{}({:.0f}%)] Loss:{:.6f},Accuracy:{:.3f}".format(
                    epoch + 1
                    , samples
                    , epochs * len(batchdata.dataset)
                    , 100 * samples / (epochs * len(batchdata.dataset))
                    , loss.item()                # .item() replaces deprecated .data.item()
                    , float(100 * correct / samples))
                    )
# Training entry point: seed the RNG for reproducibility, build the model
# with the dataset-derived input/output widths, and run the training loop.
torch.manual_seed(1412)
net = Model(in_features=input_,out_features=output_)
fit(net,batchdata,lr=lr,epochs=epochs,gamma=gamma)

运行结果:

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值