小白学Pytorch使用(4-3):花数据集分类——自定义DataLoader

任务背景

利用resnet18网络结构及预训练模型参数进行102类别的花数据集分类——自定义DataLoader处理数据集(训练部分与4-2相同,数据集汇总打乱)
数据如下:
数据集+模型

一、导入库

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
# torchvision中的transforms模块自带数据增强、数据预处理功能;models预训练模型,如resnet模型;datasets文件夹
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import copy
import json
from PIL import Image

from torch.utils.data import Dataset, DataLoader

二、自定义DataLoader

# DataLoader需要两个列表:图片路径列表、图片对应标签列表
# 定义一个花数据集处理类,名字随意
class FlowerDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=None):
        # 图片名称与类别对应文件
        self.ann_file = ann_file
        # 图片所在文件路径
        self.root_dir = root_dir
        # 读取图片名称与类别对应文件,获得图片名称与类别对应字典,key为图片名称,value为对应标签
        self.img_label = self.loaddata()
        # 字典中的所有关键字图片名称转为列表类型--->图片路径+图片名称--->得到图片路径列表
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
        # 字典中的所有值转为列表类型--->得到标签列表
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform

    def __len__(self):
        return len(self.img)

    # 对所有数据下标打乱划分为多个batch后依次返回下标idx输入,(假设batch为64,图像尺寸为256)返回64*3*256*256图像数据,64*1标签数据
    def __getitem__(self, idx):
        # 下标对应的图片
        image = Image.open(self.img[idx])
        # 下标对应标签
        label = self.label[idx]
        # transform对数据进行预处理操作输出tensor格式
        # print(image)
        # <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=704x500 at 0x2C3BC372BE0>
        if self.transform:
            image = self.transform(image)
            # print('*******************************************************')
            # print(image)
            # tensor([[[ 0.6392,  0.7077,  0.7419,  ...,  1.2385,  0.2453,  0.6049],
            #          [ 0.6734,  0.7077,  0.7419,  ...,  1.5639,  0.2624,  0.9132],
            #          [ 0.6563,  0.7077,  0.7077,  ...,  1.4098,  0.4851,  0.4851],
            #          ...,
            #          [-0.0287, -0.0287,  0.0227,  ...,  0.3652,  0.9132,  0.2796],
            #          [-0.1999, -0.1657, -0.1999,  ...,  0.9303, -0.1486,  0.5364],
            #          [-0.1999, -0.2684, -0.1999,  ...,  0.9303, -0.0287,  0.0912]],
            #
            #         [[ 0.7654,  0.8004,  0.8179,  ...,  1.5357,  0.5203,  0.9755],
            #          [ 0.7654,  0.8004,  0.8179,  ...,  1.9559,  0.5378,  1.2906],
            #          [ 0.7479,  0.7829,  0.8004,  ...,  1.9559,  0.8704,  0.8704],
            #          ...,
            #          [ 0.6954,  0.6954,  0.8704,  ...,  0.8529,  1.1506,  0.5378],
            #          [ 0.5728,  0.6254,  0.5553,  ...,  1.4657,  0.3102,  0.9405],
            #          [ 0.5903,  0.5553,  0.5903,  ...,  1.4657,  0.5378,  0.7479]],
            #
            #         [[ 0.9494,  0.9842,  1.0017,  ...,  1.0365, -0.1312,  0.3568],
            #          [ 0.9668,  0.9668,  1.0017,  ...,  1.1411, -0.0092,  0.6008],
            #          [ 0.9494,  0.9842,  1.0017,  ...,  0.7925,  0.2173,  0.2173],
            #          ...,
            #          [-0.0267, -0.0267,  0.0256,  ..., -0.2358,  0.0953, -0.3404],
            #          [-0.1835, -0.0964, -0.1312,  ..., -0.0092, -0.5670, -0.1487],
            #          [-0.2184, -0.2358, -0.1312,  ..., -0.0092, -0.3753, -0.2707]]])
            # print('*******************************************************')
        # 标签转为tensor格式,分类任务标签不变,不需要进行预处理操作
        label = torch.from_numpy(np.array(label))
        return image, label


    # 读取图片名称与类别对应文件,获得图片名称与类别对应字典,key为图片名称,value为对应标签
    def loaddata(self):
        # 空字典
        datadict = {}
        # 读取图片名称与类别对应文件
        with open(self.ann_file) as f:
            # 对文件中每一行以空格进行分割,转为列表形式。[图片名称,标签,图片名称,标签,图片名称,标签,...]
            samples = [x.strip().split(' ') for x in f.readlines()]
            # 转换为字典格式
            for image_name, label in samples:
                datadict[image_name] = np.array(label, dtype=np.int64)
        return datadict


# 数据集量较少,需进行数据增强(Data Augmentation):旋转、裁剪、翻转(水平/垂直)、平移
# 数据预处理操作,可自己根据实际情况改动
data_transforms = {
    'train':
        # Compose()顺序执行以下操作
        transforms.Compose([
        # 数据集尺寸不同,统一尺寸,可正方形(常用)、长方形,一般64、128、224、256
        # 数据越小易损失特征,但训练速度加快;CPU一般96/64
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),  #随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),      #从中心开始裁剪,实际输入网络的图像大小以裁剪后大小为准,如裁剪后为64*64,则输入大小为64*64
        transforms.RandomHorizontalFlip(p=0.5), #随机水平翻转 选择一个概率概率,每张图都有50%可能性水平翻转
        transforms.RandomVerticalFlip(p=0.5),   #随机垂直翻转,每张图都有50%可能性垂直翻转
        # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#(极端光照条件下使用)参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        # transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B,转成RRR/GGG/BBB
        transforms.ToTensor(),      #数据转换为tensor格式
        # 标准化操作。由于数据量较少不具有代表性,选用自己的均值标准差结果不稳定,因此选择大数据中提供的标准差
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#r、g、b三通道的均值u,标准差b,(x-u)/b
    ]),

    # 验证集测试模型实际训练结果,不需要数据增强
    'valid':
        transforms.Compose([
        # 与训练集输入尺寸(裁剪后尺寸)相同
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        # 与训练集均值标准差相同
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 训练集标签文件路径
train_file_path = r'D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第六章:DataLoader自定义数据集制作\flower_data\train.txt'
# 验证集标签文件路径
val_file_path = r'D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第六章:DataLoader自定义数据集制作\flower_data\val.txt'

# 训练集和验证集数据路径
path = r'D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第六章:DataLoader自定义数据集制作\flower_data'
train_path = path + '/train_data'
val_path = path + '/valid_data'

# 训练集和验证集数据处理
train_dataset = FlowerDataset(root_dir=train_path, ann_file=train_file_path, transform=data_transforms['train'])
valid_dataset = FlowerDataset(root_dir=val_path, ann_file=val_file_path, transform=data_transforms['valid'])

# DataLoader划分数据集
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

三、数据集与标签划分验证

随机选取训练集一个batch中的一个数据,压缩1维度,转为numpy格式,进行反标准化操作,展示图像,输出图像对应标签

# iter(train_loader)迭代train_loader数据,next()随机取一个batch数据
image, label = next(iter(train_loader))
# print(image.shape)
# torch.Size([64, 3, 64, 64])
# print(label)
# tensor([ 43,  72,  74,  87,  79,  27,   0,  59,  13,  63,  72,  68,  78,  87,
#          77,  72,  89,  31,  16,  82,  99,  82, 101,  50,  57,  69,  59,  79,
#           3,  50,  95,  73,  82,   2,  56,  70,  18,  87,  46,  73,  94,  90,
#          52,  75,  85,  98,  51,  36,  40,  97,  16,  86,  50,  57,  55,  80,
#          89,  28,   1,  57,  63,  16,  47,   1])

# 将维度为1的维度去除
sample = image[0].squeeze()
# print(sample.shape)
# torch.Size([3, 64, 64])
# sample的维度为torch.Size([3, 64, 64]),转换为numpy格式[64, 64, 3]
sample = sample.permute((1, 2, 0)).numpy()
# 反标准化
sample = sample * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
sample = sample.clip(0, 1)
plt.imshow(sample)
plt.show()
print('label is :{}'.format(label[0].numpy()))

结果展示:
随机图片
随机图片对应标签

三、网络训练

dataloaders = {'train':train_loader, 'valid':val_loader}

# 打开cat_to_name.json文件,文件中有数字对应的实际类别名称
with open('D:/咕泡人工智能-配套资料\配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
    # print(cat_to_name)
'''
{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers',
 '7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger', 
 '67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist', 
 '9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy', 
 '15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower',
 '62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper', 
 '42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 'sunflower', '66': 'osteospermum', '70': 'tree poppy', '85': 'desert-rose', 
 '99': 'bromelia', '87': 'magnolia', '5': 'english marigold', '92': 'bee balm', '28': 'stemless gentian', '97': 'mallow', '57': 'gaura', '40': 'lenten rose', 
 '47': 'marigold', '59': 'orange dahlia', '48': 'buttercup', '55': 'pelargonium', '36': 'ruby-lipped cattleya', '91': 'hippeastrum', '29': 'artichoke', '71': 'gazania', 
 '90': 'canna lily', '18': 'peruvian lily', '98': 'mexican petunia', '8': 'bird of paradise', '30': 'sweet william', '17': 'purple coneflower', '52': 'wild pansy', 
 '84': 'columbine', '12': "colt's foot", '11': 'snapdragon', '96': 'camellia', '23': 'fritillary', '50': 'common dandelion', '44': 'poinsettia', '53': 'primula', 
 '72': 'azalea', '65': 'californian poppy', '80': 'anthurium', '76': 'morning glory', '37': 'cape flower', '56': 'bishop of llandaff', '60': 'pink-yellow dahlia', 
 '82': 'clematis', '58': 'geranium', '75': 'thorn apple', '41': 'barbeton daisy', '95': 'bougainvillea', '43': 'sword lily', '83': 'hibiscus', '78': 'lotus lotus', 
 '88': 'cyclamen', '94': 'foxglove', '81': 'frangipani', '74': 'rose', '89': 'watercress', '73': 'water lily', '46': 'wallflower', '77': 'passion flower', 
 '51': 'petunia'}
'''

# 迁移学习:使用前人的网络结构和模型做训练
# 数据量较小时:对模型进行微小改动,如冻住某一部分不进行迭代更新训练,只训练更新较少的网络层
# 数据量中等时:对模型进行改动,如冻住少量部分不进行迭代更新训练,其他部分进行迭代更新训练
# 数据量较大时,整个模型不进行冻结,全部更新训练

# 加载models中提供的模型,并且直接用训练好的权重当作初始化参数
#可选网络结构比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],resnet网络效果较好
model_name = 'resnet'

# 是否用人家训练好的特征来做
# 此项目冻结输出层前面所有部分,不进行训练更新
feature_extract = True

# 是否用GPU训练——————固定写法
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 使用18层的resnet网络结构,18层的能快点,条件好点的也可以选152
model_ft = models.resnet18()
# resnet18网络结构
# print(model_ft)
'''
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  # resnet18输出层为1000个类别,此次项目预测类别为102,后续修改输出层
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''

# 自定义函数判断模型参数是否要进行更新
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        # 对18层resnet网络结构中的权重参数进行遍历
        for param in model.parameters():
            # requires_grad反向传播时是否更新参数,True进行更新
            param.requires_grad = True

# for param in model_ft.parameters():
#     param.requires_grad = False
# for name, param in model_ft.named_parameters():
#     # 输出所有权重参数名称及参数值
#     print(name)
#     print(param)
    # 如输出层权重参数:
'''
fc.weight
Parameter containing:
tensor([[ 0.0018, -0.0213, -0.0421,  ...,  0.0064, -0.0055, -0.0221],
        [-0.0298, -0.0036,  0.0277,  ...,  0.0181, -0.0311, -0.0340],
        [-0.0434, -0.0196, -0.0026,  ...,  0.0229,  0.0110,  0.0345],
        ...,
        [-0.0097,  0.0290,  0.0391,  ..., -0.0091, -0.0028, -0.0373],
        [ 0.0073,  0.0256,  0.0238,  ...,  0.0004, -0.0275, -0.0347],
        [-0.0356, -0.0144, -0.0418,  ...,  0.0020, -0.0338,  0.0351]])
fc.bias
Parameter containing:
tensor([ 0.0226,  0.0271, -0.0001, -0.0043, -0.0339, -0.0216,  0.0056, -0.0067,
        -0.0024,  0.0062, -0.0425,  0.0148,  0.0340, -0.0365,  0.0084,  0.0407,
        -0.0335,  0.0133,  0.0155, -0.0303, -0.0304,  0.0291, -0.0440, -0.0278,
        -0.0324, -0.0041,  0.0068, -0.0349, -0.0380,  0.0169,  0.0203, -0.0439,
         0.0185, -0.0206,  0.0164, -0.0020,  0.0067, -0.0101, -0.0166,  0.0155,
        -0.0302,  0.0095, -0.0119,  0.0124,  0.0065,  0.0435, -0.0298, -0.0022,
        -0.0379,  0.0406, -0.0065, -0.0210,  0.0151,  0.0111, -0.0251, -0.0374,
        -0.0246, -0.0052, -0.0139, -0.0246,  0.0008,  0.0318,  0.0358,  0.0370,
         0.0277, -0.0124,  0.0391, -0.0158,  0.0052, -0.0437,  0.0076, -0.0352,
        -0.0045,  0.0270, -0.0332, -0.0177, -0.0420, -0.0116, -0.0224,  0.0410,
        -0.0063,  0.0408, -0.0062, -0.0034, -0.0084, -0.0306,  0.0077,  0.0284,
         0.0437, -0.0336, -0.0037, -0.0261, -0.0234, -0.0166,  0.0153, -0.0213,
        -0.0396,  0.0089,  0.0307,  0.0027, -0.0179, -0.0098,  0.0343,  0.0375,
        -0.0432,  0.0126, -0.0060,  0.0370, -0.0030,  0.0219,  0.0140, -0.0165,
         0.0409, -0.0134, -0.0352, -0.0265, -0.0312, -0.0205,  0.0268,  0.0429,
        -0.0260,  0.0252,  0.0016, -0.0357,  0.0052, -0.0040,  0.0173, -0.0442,
        -0.0089, -0.0116,  0.0166, -0.0381,  0.0143,  0.0032, -0.0010, -0.0092,
         0.0130, -0.0224, -0.0258,  0.0187,  0.0179,  0.0061,  0.0297,  0.0404,
         0.0218,  0.0116,  0.0345, -0.0209, -0.0256,  0.0091, -0.0302, -0.0306,
         0.0239,  0.0412,  0.0408,  0.0253, -0.0016, -0.0182, -0.0015,  0.0042,
         0.0150, -0.0024,  0.0138, -0.0164,  0.0076, -0.0181,  0.0339,  0.0179,
        -0.0289, -0.0327, -0.0399,  0.0419,  0.0386, -0.0345, -0.0321, -0.0413,
         0.0390, -0.0339,  0.0359,  0.0012,  0.0298,  0.0134, -0.0128, -0.0251,
         0.0435, -0.0231, -0.0432,  0.0385, -0.0423, -0.0137,  0.0170, -0.0412,
        -0.0413,  0.0114,  0.0428, -0.0425, -0.0089, -0.0290, -0.0112,  0.0277,
        -0.0109,  0.0083,  0.0324, -0.0163, -0.0389, -0.0206,  0.0052, -0.0091,
         0.0151,  0.0244,  0.0010, -0.0173, -0.0015, -0.0437, -0.0377, -0.0107,
        -0.0185,  0.0059, -0.0198, -0.0395, -0.0129, -0.0363,  0.0408,  0.0418,
        -0.0230, -0.0122, -0.0168,  0.0281,  0.0338,  0.0321,  0.0219, -0.0041,
         0.0307, -0.0244, -0.0007, -0.0307,  0.0387,  0.0051,  0.0417, -0.0241,
        -0.0165,  0.0186,  0.0210,  0.0373,  0.0192,  0.0415, -0.0230, -0.0091,
        -0.0324, -0.0416,  0.0304,  0.0065,  0.0398,  0.0036, -0.0232, -0.0392,
         0.0109, -0.0108, -0.0320, -0.0032, -0.0138,  0.0357, -0.0247,  0.0363,
        -0.0185, -0.0197, -0.0068, -0.0120, -0.0377, -0.0101,  0.0210, -0.0243,
         0.0269,  0.0128, -0.0142, -0.0385,  0.0185,  0.0233, -0.0051, -0.0172,
         0.0224, -0.0434,  0.0078,  0.0159, -0.0201, -0.0363, -0.0246,  0.0044,
        -0.0306, -0.0377, -0.0313,  0.0366,  0.0368, -0.0303,  0.0066,  0.0322,
        -0.0143,  0.0343,  0.0233, -0.0337, -0.0211, -0.0060, -0.0167, -0.0189,
        -0.0236,  0.0292,  0.0194,  0.0372, -0.0055,  0.0430,  0.0243, -0.0126,
         0.0208,  0.0273,  0.0145,  0.0269,  0.0020, -0.0070, -0.0102,  0.0016,
        -0.0191,  0.0397, -0.0001, -0.0044, -0.0360,  0.0095,  0.0357,  0.0089,
         0.0235, -0.0244,  0.0088,  0.0222,  0.0259,  0.0096,  0.0189,  0.0390,
         0.0401, -0.0208, -0.0358, -0.0197, -0.0248, -0.0088,  0.0085,  0.0018,
        -0.0132,  0.0289,  0.0287, -0.0400,  0.0063, -0.0054,  0.0399,  0.0123,
         0.0102, -0.0326,  0.0194, -0.0049,  0.0104, -0.0171, -0.0080,  0.0429,
        -0.0056, -0.0298,  0.0064,  0.0341, -0.0191,  0.0132, -0.0174,  0.0435,
         0.0035,  0.0093, -0.0321, -0.0366,  0.0307,  0.0088, -0.0395,  0.0357,
         0.0032, -0.0149, -0.0247,  0.0124,  0.0436,  0.0126, -0.0156,  0.0050,
         0.0109,  0.0183, -0.0404, -0.0018, -0.0104, -0.0395,  0.0390, -0.0306,
         0.0261, -0.0244, -0.0253, -0.0329,  0.0334,  0.0429, -0.0138, -0.0190,
        -0.0235, -0.0204, -0.0393,  0.0217, -0.0332, -0.0438, -0.0294, -0.0158,
        -0.0103,  0.0238, -0.0419, -0.0408,  0.0113,  0.0366,  0.0221, -0.0190,
        -0.0244,  0.0335,  0.0102, -0.0101, -0.0111, -0.0284, -0.0155, -0.0114,
         0.0137,  0.0019, -0.0006, -0.0074,  0.0137, -0.0260, -0.0037,  0.0199,
         0.0155, -0.0296,  0.0173,  0.0224,  0.0091, -0.0167,  0.0004,  0.0206,
        -0.0237, -0.0195,  0.0387, -0.0045,  0.0088,  0.0261, -0.0418, -0.0144,
        -0.0375, -0.0106, -0.0354,  0.0411, -0.0053, -0.0248, -0.0010,  0.0323,
        -0.0203,  0.0012,  0.0204, -0.0320,  0.0321,  0.0137,  0.0064, -0.0329,
         0.0051, -0.0340,  0.0171,  0.0422,  0.0266,  0.0238, -0.0164,  0.0103,
        -0.0413, -0.0355,  0.0127,  0.0207, -0.0240,  0.0398,  0.0323,  0.0217,
         0.0030,  0.0396,  0.0327, -0.0060,  0.0312, -0.0117, -0.0079,  0.0095,
         0.0423,  0.0010,  0.0018,  0.0233,  0.0434, -0.0210,  0.0049, -0.0072,
         0.0031,  0.0052, -0.0292, -0.0217, -0.0253,  0.0274, -0.0230, -0.0342,
        -0.0149,  0.0137, -0.0057,  0.0344, -0.0327,  0.0370,  0.0142, -0.0194,
         0.0109,  0.0366, -0.0046, -0.0203,  0.0088, -0.0117,  0.0263,  0.0020,
        -0.0335,  0.0387, -0.0196,  0.0386, -0.0433, -0.0012,  0.0138, -0.0383,
         0.0059, -0.0313, -0.0404, -0.0103, -0.0105, -0.0196, -0.0297,  0.0331,
         0.0372, -0.0275,  0.0007,  0.0266,  0.0006,  0.0255, -0.0035, -0.0365,
         0.0165, -0.0423,  0.0054,  0.0041,  0.0026, -0.0338,  0.0436, -0.0385,
         0.0346, -0.0295, -0.0315,  0.0096,  0.0376,  0.0166, -0.0351,  0.0077,
         0.0028,  0.0139,  0.0394,  0.0115, -0.0231, -0.0134,  0.0167,  0.0160,
        -0.0003, -0.0152,  0.0399,  0.0367, -0.0424,  0.0104, -0.0025, -0.0118,
         0.0020, -0.0115, -0.0253, -0.0290,  0.0042, -0.0025, -0.0155,  0.0101,
        -0.0170,  0.0135,  0.0283,  0.0441,  0.0294, -0.0083, -0.0428, -0.0267,
        -0.0247, -0.0344,  0.0177,  0.0173,  0.0029, -0.0369, -0.0150, -0.0193,
         0.0428, -0.0401,  0.0377,  0.0138, -0.0136, -0.0104,  0.0325,  0.0335,
        -0.0152, -0.0014, -0.0287,  0.0375, -0.0426,  0.0393,  0.0016, -0.0244,
         0.0394,  0.0212, -0.0019, -0.0024, -0.0212,  0.0249, -0.0244, -0.0144,
         0.0227,  0.0073,  0.0412, -0.0194, -0.0300,  0.0084, -0.0273,  0.0157,
        -0.0270, -0.0411,  0.0153, -0.0040, -0.0268,  0.0048,  0.0164, -0.0165,
        -0.0089,  0.0328,  0.0345, -0.0013, -0.0226, -0.0097, -0.0234,  0.0156,
         0.0236,  0.0076, -0.0349, -0.0442, -0.0293,  0.0132, -0.0420, -0.0020,
         0.0059, -0.0231, -0.0187, -0.0304,  0.0270,  0.0382,  0.0407, -0.0244,
         0.0394, -0.0204, -0.0356, -0.0287, -0.0221, -0.0330,  0.0249,  0.0247,
         0.0250, -0.0180, -0.0182, -0.0358, -0.0196, -0.0441,  0.0407, -0.0027,
         0.0177,  0.0167, -0.0258, -0.0097, -0.0210, -0.0106, -0.0331,  0.0008,
        -0.0392, -0.0378,  0.0418,  0.0093, -0.0442,  0.0419, -0.0233,  0.0076,
         0.0394,  0.0428,  0.0027, -0.0117, -0.0161,  0.0153,  0.0035,  0.0222,
         0.0375, -0.0300,  0.0200,  0.0285, -0.0120,  0.0074,  0.0054, -0.0117,
        -0.0161, -0.0088,  0.0428,  0.0103,  0.0136,  0.0376, -0.0172,  0.0086,
         0.0183,  0.0381,  0.0109, -0.0177, -0.0416, -0.0351,  0.0338, -0.0404,
         0.0324,  0.0041, -0.0172,  0.0002, -0.0415,  0.0423,  0.0172, -0.0068,
         0.0370,  0.0321,  0.0040,  0.0401,  0.0030,  0.0238,  0.0102, -0.0437,
         0.0204,  0.0432,  0.0399, -0.0186, -0.0067,  0.0205, -0.0098,  0.0009,
         0.0019,  0.0317, -0.0237,  0.0062,  0.0097, -0.0312, -0.0033,  0.0028,
        -0.0021,  0.0006, -0.0320, -0.0196,  0.0363,  0.0285,  0.0321, -0.0132,
         0.0344, -0.0232,  0.0379, -0.0166,  0.0311, -0.0174,  0.0431,  0.0006,
        -0.0066, -0.0344, -0.0076,  0.0245,  0.0286, -0.0388,  0.0114,  0.0204,
         0.0137,  0.0387,  0.0206, -0.0392, -0.0109,  0.0375,  0.0269,  0.0232,
        -0.0362,  0.0235, -0.0137,  0.0303,  0.0389, -0.0068,  0.0306,  0.0273,
         0.0264, -0.0074,  0.0315, -0.0291, -0.0027, -0.0061,  0.0188, -0.0123,
        -0.0360, -0.0266,  0.0292,  0.0248,  0.0127, -0.0251, -0.0426,  0.0066,
        -0.0005, -0.0162, -0.0236, -0.0330,  0.0339,  0.0319,  0.0135,  0.0260,
        -0.0389, -0.0375, -0.0192, -0.0079, -0.0066, -0.0261, -0.0441, -0.0042,
         0.0086,  0.0291,  0.0283,  0.0028, -0.0137, -0.0218,  0.0109, -0.0052,
        -0.0213, -0.0108, -0.0354, -0.0012,  0.0006, -0.0111, -0.0066,  0.0401,
        -0.0313,  0.0203,  0.0060,  0.0339,  0.0072,  0.0148,  0.0047,  0.0363,
        -0.0117,  0.0338,  0.0238, -0.0367,  0.0093,  0.0048,  0.0222,  0.0164,
         0.0399,  0.0005,  0.0056,  0.0023, -0.0330,  0.0241, -0.0403,  0.0047,
         0.0312, -0.0148,  0.0146,  0.0268,  0.0439, -0.0265, -0.0044,  0.0441,
         0.0395, -0.0164,  0.0117,  0.0343,  0.0355,  0.0243, -0.0091,  0.0083,
        -0.0213, -0.0196, -0.0082,  0.0070,  0.0403, -0.0407,  0.0270,  0.0375,
         0.0207,  0.0207, -0.0049,  0.0020,  0.0429, -0.0432, -0.0368, -0.0040,
         0.0035,  0.0114, -0.0345, -0.0286,  0.0232,  0.0342,  0.0437,  0.0193,
         0.0030,  0.0375,  0.0161,  0.0172, -0.0069, -0.0118, -0.0235, -0.0155,
        -0.0029, -0.0035,  0.0372, -0.0343,  0.0183, -0.0057, -0.0093, -0.0322,
        -0.0303,  0.0275, -0.0364, -0.0240, -0.0090, -0.0058, -0.0055,  0.0315,
        -0.0020,  0.0268, -0.0305, -0.0286, -0.0083,  0.0015, -0.0226,  0.0249,
        -0.0133, -0.0359,  0.0393,  0.0058, -0.0354,  0.0011,  0.0424,  0.0363,
         0.0405,  0.0006, -0.0422,  0.0363, -0.0298, -0.0319, -0.0131,  0.0021,
         0.0276, -0.0302, -0.0350, -0.0433,  0.0185,  0.0263,  0.0307, -0.0093,
         0.0377,  0.0031, -0.0115, -0.0297, -0.0327,  0.0103,  0.0179, -0.0071,
         0.0029, -0.0345, -0.0335, -0.0184,  0.0426,  0.0169,  0.0039, -0.0071,
         0.0421, -0.0185, -0.0235, -0.0288, -0.0305, -0.0199,  0.0091, -0.0162,
        -0.0269,  0.0172,  0.0330, -0.0416, -0.0347, -0.0392, -0.0148, -0.0167])
    '''
    # 输出可以进行梯度更新的权重参数名称——————————因为前面修改了所有的参数更新为False,此判断无输出
    # if param.requires_grad == True:
    #     print('\t', name)


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # 加载预训练模型
    # resnet有18、50、101、152等层
    # model_name.resnet18只有网络结构,没有预训练参数。pretrained预训练模型,pretrained=True下载resnet18到C:/user/cache
    model_ft = model_name.resnet18(pretrained=use_pretrained)

    # 冻结网络结构中所有的参数更新
    set_parameter_requires_grad(model_ft, feature_extract)

    # 找到全连接层的输入参数:512
    # resnet18的全连接层:(fc): Linear(in_features=512, out_features=1000, bias=True)
    num_ftrs = model_ft.fc.in_features
    # num_classes自己的任务类别数,覆盖resnet18中的全连接输出,此全连接网络中的参数可以进行梯度更新
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    # # 输入大小根据自己配置来
    # input_size = 64
    # return model_ft, input_size
    return model_ft

# model_ft, input_size = initialize_model(models, 102, feature_extract, use_pretrained=True)
model_ft = initialize_model(models, 102, feature_extract, use_pretrained=True)

#GPU还是CPU计算
model_ft = model_ft.to(device)

# 模型保存,名字自己起——————网络结构和权重参数
filename = 'D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/best_my_false.pt'
filename_new = path + '/best_my_true.pt'

# 是否训练所有层
# 将model_ft的所有权重参数保存到params_to_update
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        # 此项目中只有新加的全连接层param.requires_grad为True
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)

# 优化器设置,只训练params_to_update中的参数
optimizer_ft = optim.Adam(params_to_update, lr=1e-3)
# 学习率衰减策略:学习率每step_size个epoch衰减成原来的gamma倍
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
# 损失函数
criterion = nn.CrossEntropyLoss()

# 加载之前训练好的权重参数
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

# 训练函数
def train_model(model, dataloaders, criterion, optimizer,  filename_new, num_epochs=25):
    # 计算训练起始时间
    since = time.time()
    # 记录训练最好的那一次的准确率
    best_acc = 0
    # 判断模型放到CPU或者GPU
    model.to(device)

    # 训练过程中打印一堆损失和指标
    # 验证准确率
    val_acc_history = []
    # 训练准确率
    train_acc_history = []
    # 训练损失
    train_losses = []
    # 验证损失
    valid_losses = []

    # 当前学习率         optimizer.param_groups是一个字典结构
    LRs = [optimizer.param_groups[0]['lr']]
    # print(optimizer.param_groups)
    '''
    [{'params': [Parameter containing:
tensor([[-0.0065, -0.0276,  0.0032,  ..., -0.0383, -0.0091,  0.0323],
        [ 0.0141,  0.0429, -0.0220,  ...,  0.0234,  0.0301, -0.0281],
        [ 0.0064, -0.0427,  0.0039,  ...,  0.0214, -0.0171, -0.0016],
        ...,
        [-0.0355,  0.0126, -0.0099,  ..., -0.0322, -0.0201,  0.0245],
        [-0.0127,  0.0114, -0.0213,  ...,  0.0270, -0.0070, -0.0315],
        [-0.0226, -0.0235,  0.0262,  ..., -0.0109,  0.0241,  0.0084]],
       requires_grad=True), Parameter containing:
tensor([-0.0118,  0.0029, -0.0184,  0.0226,  0.0082, -0.0320, -0.0046,  0.0358,
        -0.0234,  0.0430,  0.0245,  0.0431, -0.0127, -0.0231,  0.0230,  0.0357,
        -0.0181,  0.0389,  0.0127,  0.0343,  0.0044,  0.0217, -0.0323, -0.0211,
         0.0309,  0.0416, -0.0317, -0.0248,  0.0093, -0.0324, -0.0115,  0.0181,
        -0.0190, -0.0005,  0.0418, -0.0369, -0.0144, -0.0229, -0.0295, -0.0048,
         0.0088, -0.0371, -0.0203, -0.0163,  0.0073,  0.0044, -0.0410, -0.0289,
        -0.0305, -0.0363,  0.0409,  0.0364,  0.0082,  0.0419, -0.0063, -0.0100,
         0.0008, -0.0270, -0.0163,  0.0059, -0.0100,  0.0252,  0.0183, -0.0160,
         0.0027,  0.0347, -0.0131, -0.0292, -0.0225, -0.0183,  0.0326, -0.0062,
        -0.0422, -0.0220, -0.0410, -0.0408,  0.0405, -0.0046, -0.0339,  0.0411,
        -0.0015, -0.0371, -0.0152,  0.0244, -0.0128, -0.0117, -0.0275,  0.0333,
         0.0033,  0.0276,  0.0302, -0.0367,  0.0236,  0.0409,  0.0192, -0.0411,
        -0.0224, -0.0200, -0.0321,  0.0120, -0.0427,  0.0333],
       requires_grad=True)], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 
       'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.01}]
    '''
    # print(optimizer)
    '''
    Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        initial_lr: 0.01
        lr: 0.01
        maximize: False
        weight_decay: 0
    )
    '''

    # 最好的那次模型,后续会变的,先初始化————————复制当前的权重参数,model.state_dict()模型当前权重参数
    best_model_wts = copy.deepcopy(model.state_dict())

    # 一个个epoch来遍历
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 训练和验证
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # 训练
            else:
                model.eval()  # 验证

            # 初始化损失和预测正确个数
            running_loss = 0.0
            running_corrects = 0

            # 把数据都取个遍——————dataloaders字典结构,根据phase关键字决定取哪一部分
            for inputs, labels in dataloaders[phase]:
                # to(device)数据放到你的CPU或GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 梯度清零
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                # 102个类别中概率最大值的下标
                _, preds = torch.max(outputs, 1)
                # 训练阶段更新权重
                if phase == 'train':
                    loss.backward()
                    optimizer.step() #完成一次迭代

                # 累加计算总损失和总准确率
                # input格式为(batch, c, h, w),inputs.size(0)表示batch那个维度
                running_loss += loss.item() * inputs.size(0)
                # 预测结果最大的和真实值是否一致
                running_corrects += torch.sum(preds == labels.data)

            # 计算每个epoch的损失和准确率
            epoch_loss = running_loss / len(dataloaders[phase].dataset)  # 算平均
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            # 一个epoch需要多少时间
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # 得到最好那次的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),  # 字典里key就是各层的名字,值就是训练好的权重
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename_new)

            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # scheduler.step(epoch_loss)#学习率衰减
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
        scheduler.step()  # 学习率衰减

    # 总体运行花了多少时间
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 训练完后用最好的一次当做模型最终的结果,等着一会测试
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

# 训练模型
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, filename_new, num_epochs=20)


训练结果展示:
起始训练结果
结束训练结果

四、模型测试

# 加载最佳训练模型
checkpoint = torch.load(filename_new)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

# 随机得到一个batch的验证数据进行测试
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)
model_ft.eval()
if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# 得到概率最大的一个
_, preds_tensor = torch.max(output, 1)
# 判断数据是否在GPU,在的话数据转cpu中转ndarray类型,在的话直接转ndarray类型
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())


# 展示预测结果
def im_convert(tensor):
    """ 展示数据"""
    # 数据从cpu中克隆一份出来
    image = tensor.to("cpu").clone().detach()
    # 数据从tensor格式转为ndarray
    # numpy.squeeze(a, axis=None),用于从数组的形状中删除单维条目,其中a表示输入的数组,axis用于指定需要删除的维度。如果axis为空,则删除所有单维度的条目。
    image = image.numpy().squeeze()
    # PIL工具包
    # tensor中数据格式为c*h*w,正常数据格式为h*w*c,transpose()将0、1、2代表tensor中三个维度c、h、w,转换为1、2、0即h、w、c格式
    image = image.transpose(1, 2, 0)
    # 反标准化操作,均值u,标准差b,x = x*b+u
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clip()将数组中的元素值限制在给定的最小值和最大值之间,超出这个范围的值会被截断到最小值或最大值。
    image = image.clip(0, 1)
    return image


fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

测试结果展示:
随机测试结果

  • 6
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值