pytorch 全链接网络实现 手写MNIST数据集识别 附带保存图像

该博客详细介绍了如何使用PyTorch构建全连接网络识别MNIST手写数字数据集,并实现参数保存。网络结构包括两层隐藏层,激活函数为ReLU,数据预处理包含归一化。通过DataLoader加载数据,使用Adam优化器和交叉熵损失函数进行训练。代码还展示了如何将模型和参数保存到本地。训练过程中,每50个epoch输出一次损失值。
摘要由CSDN通过智能技术生成

pytorch 全链接网络实现 手写MNIST数据集识别 附带保存图像

mnist数据集图像大小 1* 28 * 28
首先我们确定网络结构:
第一层:784 * 256 + BN层 + RELU激活
第二层:256 * 128 + BN层 + RELU激活
第三层:128* 10
784也就是把28*28,可以理解为把图像数据size为一排输入网络,中间层的256 与128 的设置看情况,最好设置为2的n次方,这样能够方便电脑的计算,bn层的作用是把数据压缩到指定范围,方便损失函数的计算。
接下来依次实现:
1.导包:

import torch
from torchvision import datasets,transforms #dataset数据集
from torch.utils.data import DataLoader     #加载数据
import os

2.在网络类里初始化网络参数:

    def __init__(self):

        super().__init__()            #继承父类--init--
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(784,256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU())
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(256,128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU())
        self.fc3 = torch.nn.Linear(128,10)

3.前向过程:

 def forward(self,x):
        #N,C,H,W-->N,V
        x = torch.reshape(x,[x.size(0),-1])
        y=self.fc1(x)#N,256
        y=self.fc2(y)#N,128
        y=self.fc3(y)#N, 10
        return y

4.保存参数与数据加载

 save_params = r"./save_params/paramas.pth"   #保存参数
    save_net = r"./save_params/net.pth"          #保存网络
    transf = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5,],std=[0.5,])])
     #Compose把后面包起来,后面为列表 #totensor做了三步操作轴转换NHWC转为NCHW然后转为tensor再归一化为0-1,Normalize中的mean和
     #std意思为0-1后  再-0.5    再/0.5归一化为【-1-1】
    train_data = datasets.MNIST("./data",train=True,transform=transf,download=True)
    test_data = datasets.MNIST("./data",train=False,transform=transf,download=False) #下载数据集

    trin_loader = DataLoader   (train_data,100,True)
    test_loader = DataLoader(test_data,100,True)  #加载数据集,一次100张,True代表打乱数据。

(可以把数据打出来看看大小方便理解)
5.加载到cuda上:

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")            #cuda存在用cuda 运算

6.定义损失函数与优化器:

   loss_fn = torch.nn.CrossEntropyLoss()#多分类交叉熵   要求参数是没做one-hot的,该函数有个功能自动做onehot
    #loss_fn 也有个call方法,可以直接传参。
    # torch.nn.MSELoss()#均方差
    # torch.nn.BCELoss()#二分类交叉熵
    # optim = torch.optim.SGD(net.parameters(),lr=1e-3)
    optim = torch.optim.Adam(net.parameters(),lr=1e-3)

7.训练与测试

 net.train()
    for epoch in range(1):
        for i ,(x,y) in enumerate(trin_loader):
            x = x.to(device)
            y = y.to(device)
            out = net(x)#前向输出
            loss = loss_fn(out,y)#求损失
            optim.zero_grad()#清空当前梯度
            loss.backward()#计算当前梯度
            optim.step()#沿着当前梯度更新一步
            if i%50==0:
                print("loss",loss.item())

    net.eval()
    for i,(x,y) in enumerate(test_loader):
        x = x.to(device)
        y = y.to(device)
        out = net(x)
        loss = loss_fn(out,y)
        print()

全部代码:

import torch
from torchvision import datasets,transforms #dataset数据集
from torch.utils.data import DataLoader     #加载数据
import os
class Net(torch.nn.Module):           #定义网络类,继承nn.module主类
    def __init__(self):

        super().__init__()            #继承父类--init--
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(784,256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU())
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(256,128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU())
        self.fc3 = torch.nn.Linear(128,10)
    def forward(self,x):
        #N,C,H,W-->N,V
        x = torch.reshape(x,[x.size(0),-1])
        y=self.fc1(x)#N,256
        y=self.fc2(y)#N,128
        y=self.fc3(y)#N, 10
        return y
if __name__ == '__main__':
    save_params = r"./save_params/paramas.pth"   #保存参数
    save_net = r"./save_params/net.pth"          #保存网络
    transf = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5,],std=[0.5,])])
     #Compose把后面包起来,后面为列表 #totensor做了三步操作轴转换NHWC转为NCHW然后转为tensor再归一化为0-1,Normalize中的mean和
     #std意思为0-1后  再-0.5    再/0.5归一化为【-1-1】
    train_data = datasets.MNIST("./data",train=True,transform=transf,download=True)
    test_data = datasets.MNIST("./data",train=False,transform=transf,download=False) #下载数据集

    trin_loader = DataLoader   (train_data,100,True)
    test_loader = DataLoader(test_data,100,True)  #加载数据集,一次100张,True代表打乱数据。
    # print(train_data.data.shape)  #60000,1,28,28
    # print(train_data.targets.shape)#60000
    # print(trin_loader.shape)
    # print(test_data.targets.shape)
    # print(test_data.classes)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")            #cuda存在用cuda 运算

    net=Net().to(device)         #把网络放到cuda中去运算
    if os.path.exists(save_params):
        net.load_state_dict(torch.load(save_params))  #参数存在加载参数
        print("参数加载成功")
    else:
        print("No params!")
    # net = torch.load(save_net).to(device)#加载参数和网络
    loss_fn = torch.nn.CrossEntropyLoss()#多分类交叉熵   要求参数是没做one-hot的,该函数有个功能自动做onehot
    #loss_fn 也有个call方法,可以直接传参。
    # torch.nn.MSELoss()#均方差
    # torch.nn.BCELoss()#二分类交叉熵
    # optim = torch.optim.SGD(net.parameters(),lr=1e-3)
    optim = torch.optim.Adam(net.parameters(),lr=1e-3)

    net.train()
    for epoch in range(1):
        for i ,(x,y) in enumerate(trin_loader):
            x = x.to(device)
            y = y.to(device)
            out = net(x)#前向输出
            loss = loss_fn(out,y)#求损失
            optim.zero_grad()#清空当前梯度
            loss.backward()#计算当前梯度
            optim.step()#沿着当前梯度更新一步
            if i%50==0:
                print("loss",loss.item())

    net.eval()
    for i,(x,y) in enumerate(test_loader):
        x = x.to(device)
        y = y.to(device)
        out = net(x)
        loss = loss_fn(out,y)
        print()

    if not os.path.exists("./save_params"):
        os.mkdir("./save_params")
    torch.save(net.state_dict(),"./save_params/parmas.pth")#只保存参数
    torch.save(net,"./save_params/net.pth")

以上就是整个网络的实现过程,识别精度能够达到98-99
tips:
1. 需要注意的是在使用多分类交叉熵损失函数的时候,输出函数就不需要用softmax且也不需要对输出做onehot处理,计算损失时会自动进行softmax与onehot,如果损失函数用mse的化就才需要进行这2 项操作。不然造成的结果就是损失无法快速下降
2.数据集下载的很快,可以自己下也可以用我这种直接下载。
做的很好的盆友可以尝试下用全连接实现cifer数据集的识别要求精度达到60+。
最后是保存训练时的图像代码:

    def img_save(x,i):

        img=transforms.ToPILImage()(x[0100-1])
        
        img.save(r"D:\python projects\after10.30\11.24 deep learning\save_image/%d.jpg"%i )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值