CIFAR10 Dataset Training and Testing (2)

I. Dataset Parsing

# Read and parse the dataset
import pickle
import numpy as np
import cv2

def unpickle(file):
    # CIFAR-10 batch files are pickled with bytes keys, hence encoding='bytes'
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
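
For reference, each batch file unpickles to a dict with the keys b'batch_label', b'labels', b'data', and b'filenames'; b'data' is a 10000 x 3072 uint8 array in which each row stores the red, green, and blue planes of one 32 x 32 image. A quick check (the path matches the layout used below):

batch = unpickle(r"G:\pycharm-work\CIFAR10\data_batch_1")
print(batch.keys())          # dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
print(batch[b'data'].shape)  # (10000, 3072)
print(batch[b'data'].dtype)  # uint8
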
label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

import glob
import os

# Parse the training set
# Use raw strings so backslashes in Windows paths are not treated as escape sequences
im_train_list = glob.glob(r"G:\pycharm-work\CIFAR10\data_batch_*")
print(im_train_list)
save_path = r"G:\pycharm-work\CIFAR10\TRAIN"

for l in im_train_list:
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())

    for im_idx, im_data in enumerate(l_dict[b'data']):
        # print(im_idx)
        # print(im_data)
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        # print(im_label, im_name)
        im_label_name = label_name[im_label]
        # Each row is a flat 3072-vector (red, green, blue planes);
        # reshape to CHW, then transpose to HWC for OpenCV
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        # cv2.imshow("im_data", cv2.resize(im_data, (200, 200)))
        # cv2.waitKey(0)

        os.makedirs("{}/{}".format(save_path, im_label_name), exist_ok=True)
        # The array is in RGB order; convert to BGR so cv2.imwrite saves the colors correctly
        cv2.imwrite("{}/{}/{}".format(save_path, im_label_name,
                                      im_name.decode("utf-8")),
                    cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR))

# Parse the test set (the CIFAR-10 test file is named test_batch, so match it without a trailing underscore)
im_test_list = glob.glob(r"G:\pycharm-work\CIFAR10\test_batch*")
print(im_test_list)
save_path = r"G:\pycharm-work\CIFAR10\TEST"

for l in im_test_list:
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())

    for im_idx, im_data in enumerate(l_dict[b'data']):
        # print(im_idx)
        # print(im_data)
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        # print(im_label, im_name)
        im_label_name = label_name[im_label]
        # Same reshape/transpose as for the training set
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        # cv2.imshow("im_data", cv2.resize(im_data, (200, 200)))
        # cv2.waitKey(0)

        os.makedirs("{}/{}".format(save_path, im_label_name), exist_ok=True)
        # Convert RGB to BGR before writing with OpenCV
        cv2.imwrite("{}/{}/{}".format(save_path, im_label_name,
                                      im_name.decode("utf-8")),
                    cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR))
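
After both loops run, a minimal sanity check (assuming the directory layout above) is to count the saved files: each class folder should hold 5000 training and 1000 test images.

# Count saved images per class (paths follow the save_path layout above)
for split, expected in [("TRAIN", 5000), ("TEST", 1000)]:
    for name in label_name:
        n = len(glob.glob(r"G:\pycharm-work\CIFAR10\{}\{}\*.png".format(split, name)))
        print(split, name, n, "(expected {})".format(expected))
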

II. Custom Dataset Loading

import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision  # needed for torchvision.utils.make_grid in the training loop
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import time
import os
from PIL import Image
import numpy as np
import cv2

label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

label_dict = {}

for idx, name in enumerate(label_name):
    label_dict[name] = idx
# print(label_dict)

def default_loader(path):
    # Load an image from disk as RGB with PIL
    return Image.open(path).convert("RGB")

# Training-time data augmentation; RandomResizedCrop fixes the output size at 28x28
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((28,28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3,0.3,0.3,0.3),
    transforms.ToTensor()
]) 
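
As a quick shape check, pushing one parsed image through the pipeline should yield a 3 x 28 x 28 float tensor with values in [0, 1] (the sample is picked with glob, so no file name needs to be hard-coded):

sample_path = glob.glob(r"G:\pycharm-work\CIFAR10\TRAIN\airplane\*.png")[0]
x = train_transform(default_loader(sample_path))
print(x.shape, x.min().item(), x.max().item())  # torch.Size([3, 28, 28]), values in [0, 1]
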
class MyDataset(Dataset):
    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        for im_item in im_list:
            # print(im_item)
            # The class name is the parent folder of each image (Windows path separator)
            im_label_name = im_item.split("\\")[-2]
            imgs.append([im_item, label_dict[im_label_name]])

        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        im_path, im_label = self.imgs[index]

        im_data = self.loader(im_path)

        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    def __len__(self):
        return len(self.imgs)

im_train_list = glob.glob(r"G:\pycharm-work\CIFAR10\TRAIN\*\*.png")

im_test_list = glob.glob(r"G:\pycharm-work\CIFAR10\TEST\*\*.png")

train_data_set = MyDataset(im_train_list,transform=train_transform)

test_data_set = MyDataset(im_test_list,transform = transforms.ToTensor())

# train_data_loader = DataLoader(dataset=train_data_set,batch_size=6,shuffle=True,num_workers=4)

# test_data_loader = DataLoader(dataset=test_data_set,batch_size=6,shuffle=False,num_workers=4)

train_data_loader = DataLoader(dataset=train_data_set, batch_size=128, shuffle=True)
test_data_loader = DataLoader(dataset=test_data_set, batch_size=128, shuffle=False)

print("num_of_train",len(train_data_set))

print("num_of_test",len(test_data_set))

III. Building the Neural Network Model

class VGGBase(nn.Module):
    def __init__(self):
        super(VGGBase, self).__init__()
        # input: 3 x 28 x 28
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # 28 -> 14 after pooling (kernel 3, stride 2, padding 1)
        self.max_pooling1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        # 14 -> 7 after pooling
        self.max_pooling2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)  # 7 -> 4

        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.conv4_2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 4 -> 2

        self.fc = nn.Linear(512 * 4, 10)  # flattened features: 512 channels x 2 x 2 = 2048

    def forward(self,x):
        batchsize = x.size(0)
        out  = self.conv1(x)
        out  = self.max_pooling1(out)
        out = self.conv2_1(out)
        out = self.conv2_2(out)
        out = self.max_pooling2(out)

        out = self.conv3(out)
        out = self.conv3_2(out)
        out = self.max_pooling3(out)

        out = self.conv4(out)
        out = self.conv4_2(out)
        out = self.max_pooling4(out)

        out = out.view(batchsize, -1)

        out = self.fc(out)

        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so calling F.log_softmax here would apply it twice
        return out

def VGGNet():
    return VGGBase()
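
Before training, a dummy forward pass is a cheap way to confirm that the flattened feature map really is 512 * 4 = 2048 values (512 channels x 2 x 2 spatial) and that the head produces 10 class scores:

m = VGGNet()
x = torch.randn(2, 3, 28, 28)  # two fake 28x28 RGB inputs
print(m(x).shape)              # torch.Size([2, 10])
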

IV. Training and Testing

# Select the training device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = VGGNet()
net = net.to(device)
# Loss function
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# Optimizer
learning_rate = 0.01
# weight_decay adds L2 regularization; momentum accelerates SGD
# optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
# Learning-rate schedule: multiply by 0.9 every epoch (effectively exponential decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
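
With step_size=1 and gamma=0.9, the scheduler multiplies the learning rate by 0.9 after every epoch, i.e. epoch k runs at 0.01 * 0.9**k. The first few values:

for k in range(5):
    print(k, 0.01 * 0.9 ** k)  # 0.01, 0.009, 0.0081, 0.00729, 0.006561
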

# Set up TensorBoard
if not os.path.exists("log"):
    os.mkdir("log")
writer = SummaryWriter("log")
# Bookkeeping for the training loop
# counter of training steps
total_train_step = 0
# counter of test steps
total_test_step = 0

epoch = 30
batch_size = 128  # matches the DataLoader batch size above

start_time = time.time()

step_n = 0
for epochidx in range(epoch):
    print("----- Epoch {} training starts -----".format(epochidx + 1))
    # Training phase
    net.train()  # enable BatchNorm/Dropout training behavior
    for i, data in enumerate(train_data_loader):
        imgs, target = data
        imgs = imgs.to(device)
        target = target.to(device)
        output = net(imgs)
        loss = loss_fn(output, target)

        # Optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total_train_step += 1
        # if total_train_step % 100 == 0:
        #     end_time = time.time()
        #     print(end_time - start_time)
        #     print("train step {}, Loss: {}".format(total_train_step, loss.item()))
        #     writer.add_scalar("train_loss", loss.item(), total_train_step)

        _, pred = torch.max(output.data, dim=1)

        correct = pred.eq(target.data).cpu().sum()
        print("train epoch is ", epochidx)
        print("train lr is ", optimizer.state_dict()["param_groups"][0]["lr"])
        # Use the actual mini-batch size; the last batch may be smaller than batch_size
        print("train step", i, "loss is:", loss.item(),
              "mini_batch correct is:", 100.0 * correct / imgs.size(0))
        # Log to TensorBoard
        writer.add_scalar("train loss", loss.item(), global_step=step_n)
        writer.add_scalar("train correct", 100.0 * correct / imgs.size(0), global_step=step_n)
        im = torchvision.utils.make_grid(imgs)
        writer.add_image("train img", im, global_step=step_n)
        step_n += 1
    if not os.path.exists("models"):
        os.mkdir("models")
    torch.save(net.state_dict(), "models/{}.pth".format(epochidx + 1))
    # Update the learning rate after each epoch
    scheduler.step()

    # Evaluate on the test set
    net.eval()  # disable BatchNorm/Dropout training behavior
    sum_loss = 0
    sum_correct = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for j, data in enumerate(test_data_loader):
            imgs, target = data
            imgs = imgs.to(device)
            target = target.to(device)
            output = net(imgs)
            loss = loss_fn(output, target)

            _, pred = torch.max(output.data, dim=1)

            correct = pred.eq(target.data).cpu().sum()

            sum_loss += loss.item()
            sum_correct += correct.item()

            im = torchvision.utils.make_grid(imgs)
            writer.add_image("test img", im, global_step=step_n)
    test_loss = sum_loss * 1.0 / len(test_data_loader)
    # Divide by the dataset size so a smaller final batch is counted correctly
    test_correct = sum_correct * 100.0 / len(test_data_set)

    # Log per-epoch test metrics against the epoch index
    writer.add_scalar("test loss", test_loss, global_step=epochidx + 1)
    writer.add_scalar("test correct", test_correct, global_step=epochidx + 1)
    print("epoch is", epochidx + 1, "loss is:", test_loss,
          "test correct is:", test_correct)
writer.close()
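
To reuse a saved checkpoint for inference, a minimal sketch (the checkpoint file and the sample image below are illustrative; any saved epoch's .pth and any parsed test image work the same way):

# Load a trained checkpoint and classify one test image
model = VGGNet()
model.load_state_dict(torch.load("models/30.pth", map_location=device))
model = model.to(device)
model.eval()

img = default_loader(glob.glob(r"G:\pycharm-work\CIFAR10\TEST\cat\*.png")[0])
x = transforms.ToTensor()(img).unsqueeze(0).to(device)  # 1 x 3 x 32 x 32
with torch.no_grad():
    pred = model(x).argmax(dim=1).item()
print("predicted:", label_name[pred])
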
