PyTorch实现CIFAR-10分类代码

这篇是PyTorch学习之路第七篇,用于记录PyTorch实现CIFAR-10分类代码

(书上的代码有好多冗余)

下面示例中的数据集位于:C:\Users\22130\Learning_Pytorch\dataset(代码中以相对路径 ./dataset 引用,因此需在 Learning_Pytorch 目录下运行;若该目录下没有数据集,download=True 会自动下载 CIFAR-10)

完整代码(模型尚未训练:从头训练并测试)

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# ---- Hyper-parameters ----
train_batch_size, test_batch_size = 4, 4
num_workers = 0  # DataLoader worker processes (0 = load data in the main process)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
lr, momentum = 0.001, 0.9

# ---- Datasets and loaders ----
# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
train_dataset = torchvision.datasets.CIFAR10(
    root='./dataset', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(
    root='./dataset', train=False, transform=transform, download=True)
# Shuffle only the training data; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                          shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size,
                         shuffle=False, num_workers=num_workers)

# ---- Visualise one training batch ----
import matplotlib.pyplot as plt
import numpy as np
plt.figure()


def imshow(img):
    """Display a normalised CHW image tensor, undoing Normalize(0.5, 0.5) first."""
    restored = img / 2 + 0.5  # map [-1, 1] back to [0, 1]
    hwc = np.transpose(restored.numpy(), (1, 2, 0))  # CHW -> HWC, the layout matplotlib expects
    plt.imshow(hwc)
    plt.show()


examples = enumerate(train_loader)
# examples_target holds the class indices (0-9) of the drawn batch
idx, (examples_data, examples_target) = next(examples)
imshow(torchvision.utils.make_grid(examples_data))
# Inspect the batch that was drawn
print('--------------测试examples------------')
print('examples_target.shape:{}'.format(examples_target.shape))
print('examples_target[0]:{}'.format(examples_target[0]))
print('examples_data.shape:{}'.format(examples_data.shape))


# ---- Network definition ----
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class CNNNet(nn.Module):
    """Small two-conv-layer CNN for 3x32x32 CIFAR-10 images.

    Shape trace for a 32x32 input:
        conv1 (5x5)  -> 16x28x28, pool1 -> 16x14x14,
        conv2 (3x3)  -> 36x12x12, pool2 -> 36x6x6 = 1296 features,
        fc1 -> 128, fc2 -> 10 class logits.
    """

    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(36 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map an (N, 3, 32, 32) batch to (N, 10) unnormalised class scores (logits)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample. view(-1, 1296) would silently re-batch wrong-sized inputs
        # instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # No ReLU on the output: CrossEntropyLoss expects raw logits. Clamping them at 0
        # throws away negative scores and can leave every class tied at 0.
        return self.fc2(x)


model = CNNNet()
model = model.to(device)
print('--------------查看网络结构-----------')
print(model)

# ---- Train the model ----
print('-----训练优化器-------')
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
print("----------正式训练模型---------")
losses = []       # mean training loss per epoch
acces = []        # training accuracy (fraction in [0, 1]) per epoch
eval_losses = []
eval_acces = []
log_every = 2000  # mini-batches between running-loss prints
for epoch in range(10):
    train_loss = 0.0   # running loss, reset after every progress print
    epoch_loss = 0.0   # summed over the whole epoch, for `losses`
    num_correct = 0
    model.train()
    for i, (img, label) in enumerate(train_loader):
        img, label = img.to(device), label.to(device)
        # Clear gradients left over from the previous step
        optimizer.zero_grad()
        # Forward pass, backward pass, parameter update
        out = model(img)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()

        # Accumulate as Python numbers: .item() detaches from the graph and copies to host,
        # so `acces` never holds (possibly CUDA) tensors that plt.plot cannot handle.
        batch_loss = loss.item()
        train_loss += batch_loss
        epoch_loss += batch_loss
        _, pred = out.max(1)
        num_correct += (pred == label).sum().item()
        if i % log_every == log_every - 1:
            print('[%d,%5d] loss : %.3f' % (epoch + 1, i + 1, train_loss / log_every))
            train_loss = 0.0
    losses.append(epoch_loss / max(len(train_loader), 1))
    # Divide by the real sample count; batches * batch_size over-counts when the last batch is short.
    acces.append(num_correct / max(len(train_loader.dataset), 1))
# NOTE(review): the "load pre-trained model" listing reads './model/model.pth', but this script
# never writes it; e.g. torch.save(model.state_dict(), './model/model.pth') after training.

# Training-accuracy curve
plt.title('Train Acc')
plt.plot(np.arange(len(acces)), acces)
plt.legend(['Train Acc'], loc='upper right')
plt.show()
# ---- Evaluate on the test set ----
eval_loss = 0.0   # sum of per-batch mean losses
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
num_correct = 0   # reset: the training loop leaves its last-epoch count in this name
total = 0
model.eval()
with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to(device), label.to(device)
        out = model(img)
        eval_loss += criterion(out, label).item()
        _, pred = out.max(1)
        hits = pred == label
        num_correct += hits.sum().item()
        total += label.size(0)
        # Per-class tallies; zip follows the real batch length instead of assuming 4,
        # and avoids squeeze() collapsing a size-1 batch to a 0-d tensor.
        for cls, ok in zip(label.tolist(), hits.tolist()):
            class_correct[cls] += ok
            class_total[cls] += 1
    mean_loss = eval_loss / max(len(test_loader), 1)
    eval_acc = num_correct / total if total else 0.0   # overall accuracy as a fraction
    eval_losses.append(mean_loss)
    eval_acces.append(eval_acc)
    print("total:{}".format(total))
    print("len(test_loader):{}".format(len(test_loader)))
    for i in range(10):
        if class_total[i]:
            print("accuracy of {}:{}%".format(classes[i], 100 * class_correct[i] / class_total[i]))
        else:
            print("accuracy of {}: no test samples".format(classes[i]))
    print("----------------")
    # `epoch` is the last epoch index left over from the training loop.
    print('epoch:{}, eval_loss:{:.4f},eval_acc:{:.4f}'.format(epoch, mean_loss, eval_acc))
    print("Accuracy of the network on the %d test images:%d %%" % (total, 100 * eval_acc))

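完整代码(模型已训练,直接从 ./model/model.pth 载入参数并测试。注意:上面的训练代码并没有保存模型,需在训练结束后执行 torch.save(model.state_dict(), './model/model.pth') 生成该文件)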

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

print(torch.cuda.device_count())  # number of visible CUDA devices
# ---- Hyper-parameters (lr, momentum and train_batch_size are unused in this listing) ----
train_batch_size, test_batch_size = 4, 4
num_workers = 0  # DataLoader worker processes (0 = load data in the main process)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
lr, momentum = 0.001, 0.9

# ---- Test data only: the model is loaded pre-trained, so no training set is built ----
# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
test_dataset = torchvision.datasets.CIFAR10(
    root='./dataset', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size,
                         shuffle=False, num_workers=num_workers)


# ---- Network definition ----
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class CNNNet(nn.Module):
    """Small two-conv-layer CNN for 3x32x32 CIFAR-10 images.

    Shape trace for a 32x32 input:
        conv1 (5x5)  -> 16x28x28, pool1 -> 16x14x14,
        conv2 (3x3)  -> 36x12x12, pool2 -> 36x6x6 = 1296 features,
        fc1 -> 128, fc2 -> 10 class logits.
    Parameter names match the training listing, so its saved state_dict loads unchanged.
    """

    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(36 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map an (N, 3, 32, 32) batch to (N, 10) unnormalised class scores (logits)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample. view(-1, 1296) would silently re-batch wrong-sized inputs.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # No ReLU on the output. A ReLU there turns every negative score into 0; when all
        # scores are negative, the argmax is then taken over a row of ties.
        return self.fc2(x)

model = CNNNet()
# Load the trained parameters. map_location remaps stored tensors onto `device`, so a
# checkpoint saved on a GPU machine also loads on a CPU-only one.
state_dict = torch.load('./model/model.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
print("load success")
print('--------------查看网络结构-----------')
print(model)

# ---- Evaluate the loaded model on the test set ----
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
num_correct = 0
model.eval()
with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to(device), label.to(device)
        out = model(img)
        _, pred = out.max(1)
        c = (pred == label)
        # .item() keeps the count a Python int, so the result prints as a number, not tensor(...)
        num_correct += c.sum().item()
        # Per-class tallies; zip follows the real batch length instead of assuming 4.
        # Each bool is added as 1/0.
        for cls, ok in zip(label.tolist(), c.tolist()):
            class_correct[cls] += ok
            class_total[cls] += 1
    # Overall accuracy over the real number of test samples
    n_samples = len(test_loader.dataset)
    print("精确率为:{}".format(num_correct / n_samples if n_samples else 0.0))
    for i in range(10):
        if class_total[i]:
            print("accuracy of {}:{}%".format(classes[i], 100 * class_correct[i] / class_total[i]))
        else:
            print("accuracy of {}: no test samples".format(classes[i]))


  • 2
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值