卷积神经网络项目301--图像分类

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
def test01():
    train=CIFAR10(root='root', train=True, download=True, transform=Compose([ToTensor()]))
    valid=CIFAR10(root='root', train=False, download=True, transform=Compose([ToTensor()]))

    print('训练集数量', len(train.targets))
    print('测试集数量', len(valid.targets))
    print('数据集形状', train[0][0].shape)
    print('数据集类别', train.class_to_idx)
def test02():
    train=CIFAR10(root='root', train=True, download=True, transform=Compose([ToTensor()]))
    valid=CIFAR10(root='root', train=True, download=True, transform=Compose([ToTensor()]))
    dataloader=DataLoader(train,batch_size=8, shuffle=True)
    for x,y in dataloader:
        print(x.shape)
        print(y)
        break

加载数据

通过from torchvision.datasets import CIFAR10导入了CIFAR10数据集的模块, 然后,通过from torchvision.transforms import Compose导入了数据转换的模块, 接着,通过from torchvision.transforms import ToTensor导入了将数据转换为张量的模块, 最后,通过from torch.utils.data import DataLoader导入了一个用于加载数据的工具。

函数test01()主要是用来测试并打印CIFAR10数据集的信息。其中,使用CIFAR10()函数加载训练集和测试集,并传入了一些参数:

  • root='root'表示指定存储路径为"root"文件夹;
  • train=True表示加载训练集;
  • download=True表示如果不存在该数据集则下载;
  • transform=Compose([ToTensor()])表示对每个样本应用ToTensor()函数进行转换。

然后,输出训练集数量、测试集数量、第一个样本的形状以及数据集中各个类别对应的标签。

函数test02()主要是用来测试并打印CIFAR10训练集中批次样本的信息。同样地,使用CIFAR10()函数加载训练集和测试集,并传入相同的参数。 接着,使用DataLoader()函数创建一个数据加载器,并传入训练集、批次大小(batch_size)和是否打乱数据(shuffle=True)的参数。 然后,使用for x,y in dataloader:遍历数据加载器中的每个批次样本,并分别打印出每个批次样本的形状和标签。 最后,使用break语句结束循环,只打印第一个批次样本的信息。

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch.optim as optim
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import ssl
ssl._create_default_https_context=ssl._create_unverified_context()
class ImageClassification(nn. Module):
    def __init__(self):
        super(ImageClassification,self).__init__()
        #定义卷积池化
        self.conv1=nn.Conv2d(3, 6, stride=1, kernel_size=3)
        self.pool1=nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2=nn.Conv2d(6, 16, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        #定义线性层
        self.linear1=nn.Linear(576, 120)
        self.linear2=nn.Linear(120,84)
        self.out=nn.Linear(84, 10)
    def forward(self, x):
        x=self.conv1(x)
        x=F.relu(x)
        x=self.pool1(x)

        x=self.conv2(x)
        x=F.relu(x)
        x=self.pool2(x)
        #进入全连接层,进行维度变化。
        x=x.reshape(x.size(0), -1)
        x = self.linear1(x)
        x=F.relu(x)

        x=self.linear2(x)
        x=F.relu(x)
        return self.out(x)
#开始训练,写训练函数
def train():
    cifar10=CIFAR10(root='data', train=True, download=True, transform=Compose([ToTensor()]))
    model=ImageClassification()
    #损失函数
    criterion=nn.CrossEntropyLoss
    #优化方法
    optimizer=optim.Adam(model.parameters(), lr=1e-3)
    epochs=100
    for epoch_idx in range(epochs):
        dataloader=DataLoader(cifar10,batch_size=32,shuffle=True)
        sam_num=0
        total_loss=0.0
        start=time.time()
        correct=0
        for x, y in dataloader:
            output = model(x)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #统计信息
            correct+=(torch.argmax(output,dim=1)==y).sum()
            total_loss+=loss.item()*len(y)
            sam_num += len(y)
            print('epoch: %2s loss:%.5f acc:%.2fs time:%.2fs'%(epoch_idx+1, total_loss/ sam_num,correct/ sam_num, time.time()-start))
#模型的保存
    torch.save(model.state_dict(),'model/image_classification.pth')

    if __name__ == '__main__':
        train()

训练模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值