3.2 使用pytorch搭建AlexNet并训练花分类数据集


#详解
在这里插入图片描述

class_indices.json

{
    "0": "daisy",
    "1": "dandelion",
    "2": "roses",
    "3": "sunflowers",
    "4": "tulips"
}

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):# 继承model类
    def __init__(self, num_classes=1000, init_weights=False):#初始化参数,定义参数与层结构
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(#Sequential能把一系列的层结构打包成一个新的层结构,当前层结构被定义为提取特征的层结构
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2), #第一层 # input[3, 224, 224]  output[48, 55, 55],他只用了一半的卷积核(padding=(1,2),计算后是小数,就又一样了)
            nn.ReLU(inplace=True),#inplace是pytorch通过一种操作增加计算量减少内存占用
            nn.MaxPool2d(kernel_size=3, stride=2),#卷积核大小是3,步距是2                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(#分类器
            nn.Dropout(p=0.5),#dropout的方法上全连接层随机失活(一般放在全裂阶层之间)p值随即失火的比例
            nn.Linear(128 * 6 * 6, 2048),#linear是全连接层
            nn.ReLU(inplace=True),#激活函数
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),#输出是数据集的类别个数
        )
        if init_weights:#初始化权重,定义在下面
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)#展平
        x = self.classifier(x)
        return x

    def _initialize_weights(self):#其实不用,目前pytorch自动就是这个
        for m in self.modules():#会返回一个迭代器,遍历模型中所有的模块(遍历每一个层结构)
            if isinstance(m, nn.Conv2d):#是否是卷积
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')#是就去kaiming_normal初始化
                if m.bias is not None:#偏置不是0就置0
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):#如果是全连接层
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))#压缩掉batch的维度
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#指定训练过程中使用的设备
print(device)

data_transform = {#数据预处理函数
    "train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪到224*224像素大小
                                 transforms.RandomHorizontalFlip(),#水平随即反转
                                 transforms.ToTensor(),#转化成tensor
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path#获取数据集的根目录(os.getcwd():获取当前稳健所在目录;os.path.join合并到一起)
image_path = data_root + "/data_set/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",#加载数据集,train下面每一类是一个文件夹
                                     transform=data_transform["train"])#transform是数据预处理(之前定义的),map
train_num = len(train_dataset)#数据集有多少张图片

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx####################!!!!!获取类的名称对应的索引
cla_dict = dict((val, key) for key, val in flower_list.items())#江键值反转
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)#编码成json的格式
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=4, shuffle=True,
                                              num_workers=0)

#查看数据集的代码
# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()
#
# def imshow(img):
#     img = img / 2 + 0.5  # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
# imshow(utils.make_grid(test_image))


net = AlexNet(num_classes=5, init_weights=True)#实例化,分类集有五类,初始化权重是true

net.to(device)#分设备
#损失函数与优化器
loss_function = nn.CrossEntropyLoss()
# pata = list(net.parameters())
optimizer = optim.Adam(net.parameters(), lr=0.0002)

save_path = './AlexNet.pth'
best_acc = 0.0#用来保存最佳平均准确率,为了保存效果最好的一次模型
for epoch in range(10):#10轮
    # train
    net.train()#用net.train()与net.eavl() 因为用了dropout,希望只在训练时失活,所以用这个来管理dropout
    running_loss = 0.0
    t1 = time.perf_counter()#统计训练一个epoch所使用的时间
    for step, data in enumerate(train_loader, start=0):#遍历数据集
        images, labels = data#将数据分成图像与对应的标签
        optimizer.zero_grad()#清空梯度信息
        outputs = net(images.to(device))#正向传播,将图像也指认到设备上
        loss = loss_function(outputs, labels.to(device))#计算损失
        loss.backward()#反向传播
        optimizer.step()#更新参数

        # print statistics
        running_loss += loss.item()#将loss的值累加到runningloss中(loss。item才是loss的值)
        # print train process,打印训练进度
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1)

    # validate验证
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():#禁止损失跟踪
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()#计算准确个数
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)#保存权重
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

创建自己的数据集

偷懒的办法
在flowerdata下直接全删了改自己的
在这里插入图片描述
在这里插入图片描述

  • 2
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值