小白学Pytorch使用(4-1):花数据集分类——迁移学习

任务背景

利用resnet18网络结构及预训练模型参数进行102类别的花数据集分类,迁移学习冻结resnet18输出层外的权重参数更新,保存最好的训练模型pt文件。
数据如下:
花数据集+json文件+2个迁移学习模型

一、导入库

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
# torchvision中的transforms模块自带数据增强、数据预处理功能;models预训练模型,如resnet模型;datasets文件夹
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import copy
import json
from PIL import Image

二、数据导入

# 导入数据路径——改为自己的数据路径
data_dir = r'D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

三、数据预处理

# 数据集量较少,需进行数据增强(Data Augmentation):旋转、裁剪、翻转(水平/垂直)、平移
data_transforms = {
    'train':
        # Compose()顺序执行以下操作
        transforms.Compose([
        # 数据集尺寸不同,统一尺寸,可正方形(常用)、长方形,一般64、128、224、256
        # 数据越小易损失特征,但训练速度加快;CPU一般96/64
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),  #随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),      #从中心开始裁剪,实际输入网络的图像大小以裁剪后大小为准,如裁剪后为64*64,则输入大小为64*64
        transforms.RandomHorizontalFlip(p=0.5), #随机水平翻转 选择一个概率概率,每张图都有50%可能性水平翻转
        transforms.RandomVerticalFlip(p=0.5),   #随机垂直翻转,每张图都有50%可能性垂直翻转
        # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#(极端光照条件下使用,一般不使用)参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        # transforms.RandomGrayscale(p=0.025),#(一般不使用)概率转换成灰度率,3通道就是R=G=B,转成RRR/GGG/BBB
        transforms.ToTensor(),      #数据转换为tensor格式
        # 标准化操作。由于数据量较少不具有代表性,选用自己的均值标准差结果不稳定,因此选择大数据中提供的标准差
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#r、g、b三通道的均值u,标准差b,(x-u)/b
    ]),

    # 验证集测试模型实际训练结果,不需要数据增强
    'valid':
        transforms.Compose([
        # 与训练集输入尺寸(裁剪后尺寸)相同
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        # 与训练集均值标准差相同
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 由于输入图像尺寸较小,batch_size可以大些。考虑电脑显存问题
batch_size = 128

# 将训练集和验证集文件夹与图像预处理操作对应起来————————image_datasets为字典类型,datasets数据以文件夹形式处理(数据集中的数据按类别划分了单独文件夹,因此以文件夹形式处理)
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
# print(image_datasets)
'''
{'train': Dataset ImageFolder
    Number of datapoints: 6552
    Root location: D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data\train
    StandardTransform
Transform: Compose(
               Resize(size=[96, 96], interpolation=bilinear, max_size=None, antialias=True)
               RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
               CenterCrop(size=(64, 64))
               RandomHorizontalFlip(p=0.5)
               RandomVerticalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           ), 'valid': Dataset ImageFolder
    Number of datapoints: 818
    Root location: D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data\valid
    StandardTransform
Transform: Compose(
               Resize(size=[64, 64], interpolation=bilinear, max_size=None, antialias=True)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )}
'''

# 将训练集和验证集文件夹中的数据打乱划分batch——————dataloaders为字典形式
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
# print(dataloaders)
# {'train': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07130>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07100>}

# 计算训练集和验证集分别数据数目
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
# 获取分类标签
class_names = image_datasets['train'].classes
# print(class_names)
'''
['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', 
    '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', 
    '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', 
    '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89',
    '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
'''

# 打开cat_to_name.json文件,文件中有数字对应的实际类别名称——数据标签文件
with open('D:/咕泡人工智能-配套资料\配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
    # print(cat_to_name)
'''
{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers',
 '7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger', 
 '67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist', 
 '9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy', 
 '15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower',
 '62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper', 
 '42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 'sunflower', '66': 'osteospermum', '70': 'tree poppy', '85': 'desert-rose', 
 '99': 'bromelia', '87': 'magnolia', '5': 'english marigold', '92': 'bee balm', '28': 'stemless gentian', '97': 'mallow', '57': 'gaura', '40': 'lenten rose', 
 '47': 'marigold', '59': 'orange dahlia', '48': 'buttercup', '55': 'pelargonium', '36': 'ruby-lipped cattleya', '91': 'hippeastrum', '29': 'artichoke', '71': 'gazania', 
 '90': 'canna lily', '18': 'peruvian lily', '98': 'mexican petunia', '8': 'bird of paradise', '30': 'sweet william', '17': 'purple coneflower', '52': 'wild pansy', 
 '84': 'columbine', '12': "colt's foot", '11': 'snapdragon', '96': 'camellia', '23': 'fritillary', '50': 'common dandelion', '44': 'poinsettia', '53': 'primula', 
 '72': 'azalea', '65': 'californian poppy', '80': 'anthurium', '76': 'morning glory', '37': 'cape flower', '56': 'bishop of llandaff', '60': 'pink-yellow dahlia', 
 '82': 'clematis', '58': 'geranium', '75': 'thorn apple', '41': 'barbeton daisy', '95': 'bougainvillea', '43': 'sword lily', '83': 'hibiscus', '78': 'lotus lotus', 
 '88': 'cyclamen', '94': 'foxglove', '81': 'frangipani', '74': 'rose', '89': 'watercress', '73': 'water lily', '46': 'wallflower', '77': 'passion flower', 
 '51': 'petunia'}
'''

四、导入Resnet网络——迁移学习

# 迁移学习:使用前人的网络结构和模型做训练
# 数据量较小时:对模型进行微小改动,如冻住某一部分不进行迭代更新训练,只训练更新较少的网络层
# 数据量中等时:对模型进行改动,如冻住少量部分不进行迭代更新训练,其他部分进行迭代更新训练
# 数据量较大时,整个模型不进行冻结,全部更新训练

# 加载models中提供的模型,并且直接用训练好的权重当作初始化参数
#可选网络结构比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],resnet网络效果较好
model_name = 'resnet'

# 是否用人家训练好的特征来做
# 此项目冻结输出层前面所有部分,不进行训练更新
feature_extract = True

# 是否用GPU训练——————固定写法
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 使用18层的resnet网络结构,18层的能快点,条件好点的也可以选152
model_ft = models.resnet18()
# resnet18网络结构
# print(model_ft)
'''
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  # resnet18输出层为1000个类别,此次项目预测类别为102,后续修改输出层
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''

# 自定义函数判断模型参数是否要进行更新
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        # 对18层resnet网络结构中的权重参数进行遍历
        for param in model.parameters():
            # requires_grad反向传播时是否更新参数,False不进行更新
            param.requires_grad = False

# for param in model_ft.parameters():
#     param.requires_grad = False
# for name, param in model_ft.named_parameters():
#     # 输出所有权重参数名称及参数值
#     print(name)
#     print(param)
# 输出层权重参数如下:
'''
fc.weight
Parameter containing:
tensor([[ 0.0018, -0.0213, -0.0421,  ...,  0.0064, -0.0055, -0.0221],
        [-0.0298, -0.0036,  0.0277,  ...,  0.0181, -0.0311, -0.0340],
        [-0.0434, -0.0196, -0.0026,  ...,  0.0229,  0.0110,  0.0345],
        ...,
        [-0.0097,  0.0290,  0.0391,  ..., -0.0091, -0.0028, -0.0373],
        [ 0.0073,  0.0256,  0.0238,  ...,  0.0004, -0.0275, -0.0347],
        [-0.0356, -0.0144, -0.0418,  ...,  0.0020, -0.0338,  0.0351]])
fc.bias
Parameter containing:
tensor([ 0.0226,  0.0271, -0.0001, -0.0043, -0.0339, -0.0216,  0.0056, -0.0067,
        -0.0024,  0.0062, -0.0425,  0.0148,  0.0340, -0.0365,  0.0084,  0.0407,
        -0.0335,  0.0133,  0.0155, -0.0303, -0.0304,  0.0291, -0.0440, -0.0278,
        -0.0324, -0.0041,  0.0068, -0.0349, -0.0380,  0.0169,  0.0203, -0.0439,
         0.0185, -0.0206,  0.0164, -0.0020,  0.0067, -0.0101, -0.0166,  0.0155,
        -0.0302,  0.0095, -0.0119,  0.0124,  0.0065,  0.0435, -0.0298, -0.0022,
        -0.0379,  0.0406, -0.0065, -0.0210,  0.0151,  0.0111, -0.0251, -0.0374,
        -0.0246, -0.0052, -0.0139, -0.0246,  0.0008,  0.0318,  0.0358,  0.0370,
         0.0277, -0.0124,  0.0391, -0.0158,  0.0052, -0.0437,  0.0076, -0.0352,
        -0.0045,  0.0270, -0.0332, -0.0177, -0.0420, -0.0116, -0.0224,  0.0410,
        -0.0063,  0.0408, -0.0062, -0.0034, -0.0084, -0.0306,  0.0077,  0.0284,
         0.0437, -0.0336, -0.0037, -0.0261, -0.0234, -0.0166,  0.0153, -0.0213,
        -0.0396,  0.0089,  0.0307,  0.0027, -0.0179, -0.0098,  0.0343,  0.0375,
        -0.0432,  0.0126, -0.0060,  0.0370, -0.0030,  0.0219,  0.0140, -0.0165,
         0.0409, -0.0134, -0.0352, -0.0265, -0.0312, -0.0205,  0.0268,  0.0429,
        -0.0260,  0.0252,  0.0016, -0.0357,  0.0052, -0.0040,  0.0173, -0.0442,
        -0.0089, -0.0116,  0.0166, -0.0381,  0.0143,  0.0032, -0.0010, -0.0092,
         0.0130, -0.0224, -0.0258,  0.0187,  0.0179,  0.0061,  0.0297,  0.0404,
         0.0218,  0.0116,  0.0345, -0.0209, -0.0256,  0.0091, -0.0302, -0.0306,
         0.0239,  0.0412,  0.0408,  0.0253, -0.0016, -0.0182, -0.0015,  0.0042,
         0.0150, -0.0024,  0.0138, -0.0164,  0.0076, -0.0181,  0.0339,  0.0179,
        -0.0289, -0.0327, -0.0399,  0.0419,  0.0386, -0.0345, -0.0321, -0.0413,
         0.0390, -0.0339,  0.0359,  0.0012,  0.0298,  0.0134, -0.0128, -0.0251,
         0.0435, -0.0231, -0.0432,  0.0385, -0.0423, -0.0137,  0.0170, -0.0412,
        -0.0413,  0.0114,  0.0428, -0.0425, -0.0089, -0.0290, -0.0112,  0.0277,
        -0.0109,  0.0083,  0.0324, -0.0163, -0.0389, -0.0206,  0.0052, -0.0091,
         0.0151,  0.0244,  0.0010, -0.0173, -0.0015, -0.0437, -0.0377, -0.0107,
        -0.0185,  0.0059, -0.0198, -0.0395, -0.0129, -0.0363,  0.0408,  0.0418,
        -0.0230, -0.0122, -0.0168,  0.0281,  0.0338,  0.0321,  0.0219, -0.0041,
         0.0307, -0.0244, -0.0007, -0.0307,  0.0387,  0.0051,  0.0417, -0.0241,
        -0.0165,  0.0186,  0.0210,  0.0373,  0.0192,  0.0415, -0.0230, -0.0091,
        -0.0324, -0.0416,  0.0304,  0.0065,  0.0398,  0.0036, -0.0232, -0.0392,
         0.0109, -0.0108, -0.0320, -0.0032, -0.0138,  0.0357, -0.0247,  0.0363,
        -0.0185, -0.0197, -0.0068, -0.0120, -0.0377, -0.0101,  0.0210, -0.0243,
         0.0269,  0.0128, -0.0142, -0.0385,  0.0185,  0.0233, -0.0051, -0.0172,
         0.0224, -0.0434,  0.0078,  0.0159, -0.0201, -0.0363, -0.0246,  0.0044,
        -0.0306, -0.0377, -0.0313,  0.0366,  0.0368, -0.0303,  0.0066,  0.0322,
        -0.0143,  0.0343,  0.0233, -0.0337, -0.0211, -0.0060, -0.0167, -0.0189,
        -0.0236,  0.0292,  0.0194,  0.0372, -0.0055,  0.0430,  0.0243, -0.0126,
         0.0208,  0.0273,  0.0145,  0.0269,  0.0020, -0.0070, -0.0102,  0.0016,
        -0.0191,  0.0397, -0.0001, -0.0044, -0.0360,  0.0095,  0.0357,  0.0089,
         0.0235, -0.0244,  0.0088,  0.0222,  0.0259,  0.0096,  0.0189,  0.0390,
         0.0401, -0.0208, -0.0358, -0.0197, -0.0248, -0.0088,  0.0085,  0.0018,
        -0.0132,  0.0289,  0.0287, -0.0400,  0.0063, -0.0054,  0.0399,  0.0123,
         0.0102, -0.0326,  0.0194, -0.0049,  0.0104, -0.0171, -0.0080,  0.0429,
        -0.0056, -0.0298,  0.0064,  0.0341, -0.0191,  0.0132, -0.0174,  0.0435,
         0.0035,  0.0093, -0.0321, -0.0366,  0.0307,  0.0088, -0.0395,  0.0357,
         0.0032, -0.0149, -0.0247,  0.0124,  0.0436,  0.0126, -0.0156,  0.0050,
         0.0109,  0.0183, -0.0404, -0.0018, -0.0104, -0.0395,  0.0390, -0.0306,
         0.0261, -0.0244, -0.0253, -0.0329,  0.0334,  0.0429, -0.0138, -0.0190,
        -0.0235, -0.0204, -0.0393,  0.0217, -0.0332, -0.0438, -0.0294, -0.0158,
        -0.0103,  0.0238, -0.0419, -0.0408,  0.0113,  0.0366,  0.0221, -0.0190,
        -0.0244,  0.0335,  0.0102, -0.0101, -0.0111, -0.0284, -0.0155, -0.0114,
         0.0137,  0.0019, -0.0006, -0.0074,  0.0137, -0.0260, -0.0037,  0.0199,
         0.0155, -0.0296,  0.0173,  0.0224,  0.0091, -0.0167,  0.0004,  0.0206,
        -0.0237, -0.0195,  0.0387, -0.0045,  0.0088,  0.0261, -0.0418, -0.0144,
        -0.0375, -0.0106, -0.0354,  0.0411, -0.0053, -0.0248, -0.0010,  0.0323,
        -0.0203,  0.0012,  0.0204, -0.0320,  0.0321,  0.0137,  0.0064, -0.0329,
         0.0051, -0.0340,  0.0171,  0.0422,  0.0266,  0.0238, -0.0164,  0.0103,
        -0.0413, -0.0355,  0.0127,  0.0207, -0.0240,  0.0398,  0.0323,  0.0217,
         0.0030,  0.0396,  0.0327, -0.0060,  0.0312, -0.0117, -0.0079,  0.0095,
         0.0423,  0.0010,  0.0018,  0.0233,  0.0434, -0.0210,  0.0049, -0.0072,
         0.0031,  0.0052, -0.0292, -0.0217, -0.0253,  0.0274, -0.0230, -0.0342,
        -0.0149,  0.0137, -0.0057,  0.0344, -0.0327,  0.0370,  0.0142, -0.0194,
         0.0109,  0.0366, -0.0046, -0.0203,  0.0088, -0.0117,  0.0263,  0.0020,
        -0.0335,  0.0387, -0.0196,  0.0386, -0.0433, -0.0012,  0.0138, -0.0383,
         0.0059, -0.0313, -0.0404, -0.0103, -0.0105, -0.0196, -0.0297,  0.0331,
         0.0372, -0.0275,  0.0007,  0.0266,  0.0006,  0.0255, -0.0035, -0.0365,
         0.0165, -0.0423,  0.0054,  0.0041,  0.0026, -0.0338,  0.0436, -0.0385,
         0.0346, -0.0295, -0.0315,  0.0096,  0.0376,  0.0166, -0.0351,  0.0077,
         0.0028,  0.0139,  0.0394,  0.0115, -0.0231, -0.0134,  0.0167,  0.0160,
        -0.0003, -0.0152,  0.0399,  0.0367, -0.0424,  0.0104, -0.0025, -0.0118,
         0.0020, -0.0115, -0.0253, -0.0290,  0.0042, -0.0025, -0.0155,  0.0101,
        -0.0170,  0.0135,  0.0283,  0.0441,  0.0294, -0.0083, -0.0428, -0.0267,
        -0.0247, -0.0344,  0.0177,  0.0173,  0.0029, -0.0369, -0.0150, -0.0193,
         0.0428, -0.0401,  0.0377,  0.0138, -0.0136, -0.0104,  0.0325,  0.0335,
        -0.0152, -0.0014, -0.0287,  0.0375, -0.0426,  0.0393,  0.0016, -0.0244,
         0.0394,  0.0212, -0.0019, -0.0024, -0.0212,  0.0249, -0.0244, -0.0144,
         0.0227,  0.0073,  0.0412, -0.0194, -0.0300,  0.0084, -0.0273,  0.0157,
        -0.0270, -0.0411,  0.0153, -0.0040, -0.0268,  0.0048,  0.0164, -0.0165,
        -0.0089,  0.0328,  0.0345, -0.0013, -0.0226, -0.0097, -0.0234,  0.0156,
         0.0236,  0.0076, -0.0349, -0.0442, -0.0293,  0.0132, -0.0420, -0.0020,
         0.0059, -0.0231, -0.0187, -0.0304,  0.0270,  0.0382,  0.0407, -0.0244,
         0.0394, -0.0204, -0.0356, -0.0287, -0.0221, -0.0330,  0.0249,  0.0247,
         0.0250, -0.0180, -0.0182, -0.0358, -0.0196, -0.0441,  0.0407, -0.0027,
         0.0177,  0.0167, -0.0258, -0.0097, -0.0210, -0.0106, -0.0331,  0.0008,
        -0.0392, -0.0378,  0.0418,  0.0093, -0.0442,  0.0419, -0.0233,  0.0076,
         0.0394,  0.0428,  0.0027, -0.0117, -0.0161,  0.0153,  0.0035,  0.0222,
         0.0375, -0.0300,  0.0200,  0.0285, -0.0120,  0.0074,  0.0054, -0.0117,
        -0.0161, -0.0088,  0.0428,  0.0103,  0.0136,  0.0376, -0.0172,  0.0086,
         0.0183,  0.0381,  0.0109, -0.0177, -0.0416, -0.0351,  0.0338, -0.0404,
         0.0324,  0.0041, -0.0172,  0.0002, -0.0415,  0.0423,  0.0172, -0.0068,
         0.0370,  0.0321,  0.0040,  0.0401,  0.0030,  0.0238,  0.0102, -0.0437,
         0.0204,  0.0432,  0.0399, -0.0186, -0.0067,  0.0205, -0.0098,  0.0009,
         0.0019,  0.0317, -0.0237,  0.0062,  0.0097, -0.0312, -0.0033,  0.0028,
        -0.0021,  0.0006, -0.0320, -0.0196,  0.0363,  0.0285,  0.0321, -0.0132,
         0.0344, -0.0232,  0.0379, -0.0166,  0.0311, -0.0174,  0.0431,  0.0006,
        -0.0066, -0.0344, -0.0076,  0.0245,  0.0286, -0.0388,  0.0114,  0.0204,
         0.0137,  0.0387,  0.0206, -0.0392, -0.0109,  0.0375,  0.0269,  0.0232,
        -0.0362,  0.0235, -0.0137,  0.0303,  0.0389, -0.0068,  0.0306,  0.0273,
         0.0264, -0.0074,  0.0315, -0.0291, -0.0027, -0.0061,  0.0188, -0.0123,
        -0.0360, -0.0266,  0.0292,  0.0248,  0.0127, -0.0251, -0.0426,  0.0066,
        -0.0005, -0.0162, -0.0236, -0.0330,  0.0339,  0.0319,  0.0135,  0.0260,
        -0.0389, -0.0375, -0.0192, -0.0079, -0.0066, -0.0261, -0.0441, -0.0042,
         0.0086,  0.0291,  0.0283,  0.0028, -0.0137, -0.0218,  0.0109, -0.0052,
        -0.0213, -0.0108, -0.0354, -0.0012,  0.0006, -0.0111, -0.0066,  0.0401,
        -0.0313,  0.0203,  0.0060,  0.0339,  0.0072,  0.0148,  0.0047,  0.0363,
        -0.0117,  0.0338,  0.0238, -0.0367,  0.0093,  0.0048,  0.0222,  0.0164,
         0.0399,  0.0005,  0.0056,  0.0023, -0.0330,  0.0241, -0.0403,  0.0047,
         0.0312, -0.0148,  0.0146,  0.0268,  0.0439, -0.0265, -0.0044,  0.0441,
         0.0395, -0.0164,  0.0117,  0.0343,  0.0355,  0.0243, -0.0091,  0.0083,
        -0.0213, -0.0196, -0.0082,  0.0070,  0.0403, -0.0407,  0.0270,  0.0375,
         0.0207,  0.0207, -0.0049,  0.0020,  0.0429, -0.0432, -0.0368, -0.0040,
         0.0035,  0.0114, -0.0345, -0.0286,  0.0232,  0.0342,  0.0437,  0.0193,
         0.0030,  0.0375,  0.0161,  0.0172, -0.0069, -0.0118, -0.0235, -0.0155,
        -0.0029, -0.0035,  0.0372, -0.0343,  0.0183, -0.0057, -0.0093, -0.0322,
        -0.0303,  0.0275, -0.0364, -0.0240, -0.0090, -0.0058, -0.0055,  0.0315,
        -0.0020,  0.0268, -0.0305, -0.0286, -0.0083,  0.0015, -0.0226,  0.0249,
        -0.0133, -0.0359,  0.0393,  0.0058, -0.0354,  0.0011,  0.0424,  0.0363,
         0.0405,  0.0006, -0.0422,  0.0363, -0.0298, -0.0319, -0.0131,  0.0021,
         0.0276, -0.0302, -0.0350, -0.0433,  0.0185,  0.0263,  0.0307, -0.0093,
         0.0377,  0.0031, -0.0115, -0.0297, -0.0327,  0.0103,  0.0179, -0.0071,
         0.0029, -0.0345, -0.0335, -0.0184,  0.0426,  0.0169,  0.0039, -0.0071,
         0.0421, -0.0185, -0.0235, -0.0288, -0.0305, -0.0199,  0.0091, -0.0162,
        -0.0269,  0.0172,  0.0330, -0.0416, -0.0347, -0.0392, -0.0148, -0.0167])
    '''
    # 输出可以进行梯度更新的权重参数名称——————————因为前面修改了所有的参数更新为False,此判断无输出
    # if param.requires_grad == True:
    #     print('\t', name)

五、模型训练及参数配置

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # 加载预训练模型
    # resnet有18、50、101、152等层
    # model_name.resnet18只有网络结构,没有预训练参数。pretrained是否加载预训练模型,pretrained=True下载resnet18到C:/user/cache
    model_ft = model_name.resnet18(pretrained=use_pretrained)

    # 冻结网络结构中所有的参数更新
    set_parameter_requires_grad(model_ft, feature_extract)

    # 找到全连接层的输入参数:512
    # resnet18的全连接层:(fc): Linear(in_features=512, out_features=1000, bias=True)
    num_ftrs = model_ft.fc.in_features
    # num_classes自己的任务类别数,覆盖resnet18中的全连接输出,此全连接网络中的参数可以进行梯度更新
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    # # 输入大小根据自己配置来
    # input_size = 64
    # return model_ft, input_size
    return model_ft

# input_size好像没用处,删去
# model_ft, input_size = initialize_model(models, 102, feature_extract, use_pretrained=True)
model_ft = initialize_model(models, 102, feature_extract, use_pretrained=True)

#GPU还是CPU计算
model_ft = model_ft.to(device)

# 模型保存路径,名字自己起——————网络结构和权重参数
filename = 'D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/best_my.pt'

# 是否训练所有层
# 将model_ft的所有权重参数保存到params_to_update
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        # 此项目中只有新加的全连接层param.requires_grad为True,进行参数更新
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)

# 优化器设置,只训练params_to_update中的参数
optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# 学习率衰减策略:学习率每step_size个epoch衰减成原来的gamma倍
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
# 损失函数
criterion = nn.CrossEntropyLoss()

# 训练函数
def train_model(model, dataloaders, criterion, optimizer,  filename, num_epochs=25):
    # 计算训练起始时间
    since = time.time()
    # 记录训练最好的那一次的准确率
    best_acc = 0
    # 判断模型放到CPU或者GPU
    model.to(device)

    # 训练过程中打印一堆损失和指标
    # 验证准确率
    val_acc_history = []
    # 训练准确率
    train_acc_history = []
    # 训练损失
    train_losses = []
    # 验证损失
    valid_losses = []

    # 当前学习率         optimizer.param_groups是一个字典结构
    LRs = [optimizer.param_groups[0]['lr']]
    # print(optimizer.param_groups)
    '''
    [{'params': [Parameter containing:
tensor([[-0.0065, -0.0276,  0.0032,  ..., -0.0383, -0.0091,  0.0323],
        [ 0.0141,  0.0429, -0.0220,  ...,  0.0234,  0.0301, -0.0281],
        [ 0.0064, -0.0427,  0.0039,  ...,  0.0214, -0.0171, -0.0016],
        ...,
        [-0.0355,  0.0126, -0.0099,  ..., -0.0322, -0.0201,  0.0245],
        [-0.0127,  0.0114, -0.0213,  ...,  0.0270, -0.0070, -0.0315],
        [-0.0226, -0.0235,  0.0262,  ..., -0.0109,  0.0241,  0.0084]],
       requires_grad=True), Parameter containing:
tensor([-0.0118,  0.0029, -0.0184,  0.0226,  0.0082, -0.0320, -0.0046,  0.0358,
        -0.0234,  0.0430,  0.0245,  0.0431, -0.0127, -0.0231,  0.0230,  0.0357,
        -0.0181,  0.0389,  0.0127,  0.0343,  0.0044,  0.0217, -0.0323, -0.0211,
         0.0309,  0.0416, -0.0317, -0.0248,  0.0093, -0.0324, -0.0115,  0.0181,
        -0.0190, -0.0005,  0.0418, -0.0369, -0.0144, -0.0229, -0.0295, -0.0048,
         0.0088, -0.0371, -0.0203, -0.0163,  0.0073,  0.0044, -0.0410, -0.0289,
        -0.0305, -0.0363,  0.0409,  0.0364,  0.0082,  0.0419, -0.0063, -0.0100,
         0.0008, -0.0270, -0.0163,  0.0059, -0.0100,  0.0252,  0.0183, -0.0160,
         0.0027,  0.0347, -0.0131, -0.0292, -0.0225, -0.0183,  0.0326, -0.0062,
        -0.0422, -0.0220, -0.0410, -0.0408,  0.0405, -0.0046, -0.0339,  0.0411,
        -0.0015, -0.0371, -0.0152,  0.0244, -0.0128, -0.0117, -0.0275,  0.0333,
         0.0033,  0.0276,  0.0302, -0.0367,  0.0236,  0.0409,  0.0192, -0.0411,
        -0.0224, -0.0200, -0.0321,  0.0120, -0.0427,  0.0333],
       requires_grad=True)], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 
       'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.01}]
    '''
    # 查看优化器参数
    # print(optimizer)
    '''
    Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        initial_lr: 0.01
        lr: 0.01
        maximize: False
        weight_decay: 0
    )
    '''

    # 最好的那次模型,后续会变的,先初始化————————复制当前的权重参数,model.state_dict()模型当前权重参数
    best_model_wts = copy.deepcopy(model.state_dict())

    # 一个个epoch来遍历
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 训练和验证
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # 训练
            else:
                model.eval()  # 验证

            # 初始化损失和预测正确个数
            running_loss = 0.0
            running_corrects = 0

            # 把数据都取个遍——————dataloaders字典结构,根据phase关键字决定取哪一部分
            for inputs, labels in dataloaders[phase]:
                # to(device)数据放到你的CPU或GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 梯度清零
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                # 102个类别中概率最大值的下标
                _, preds = torch.max(outputs, 1)
                # 训练阶段更新权重
                if phase == 'train':
                    loss.backward()
                    optimizer.step() #完成一次迭代

                # 累加计算总损失和总准确率
                # input格式为(batch, c, h, w),inputs.size(0)表示batch那个维度
                running_loss += loss.item() * inputs.size(0)
                # 预测结果最大的和真实值是否一致
                running_corrects += torch.sum(preds == labels.data)

            # 计算每个epoch的损失和准确率
            epoch_loss = running_loss / len(dataloaders[phase].dataset)  # 算平均
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            # 一个epoch需要多少时间
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # 得到最好那次的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),  # 字典里key就是各层的名字,值就是训练好的权重
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename)

            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # scheduler.step(epoch_loss)#学习率衰减
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
        scheduler.step()  # 学习率衰减

    # 总体运行花了多少时间
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 训练完后用最好的一次当做模型最终的结果,等着一会测试
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

# 训练模型
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, filename, num_epochs=20)

训练过程展示:
训练起始结果
训练结束结果

六、验证最佳训练模型

# 加载最佳训练模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

# 随机得到一个batch的验证数据进行测试
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)
model_ft.eval()
if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# 得到概率最大的一个
_, preds_tensor = torch.max(output, 1)
# 判断数据是否在GPU,在的话数据转cpu中转ndarray类型,在的话直接转ndarray类型
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())


# 展示预测结果
def im_convert(tensor):
    """ 展示数据"""
    # 数据从cpu中克隆一份出来
    image = tensor.to("cpu").clone().detach()
    # 数据从tensor格式转为ndarray
    # numpy.squeeze(a, axis=None),用于从数组的形状中删除单维条目,其中a表示输入的数组,axis用于指定需要删除的维度。如果axis为空,则删除所有单维度的条目。
    image = image.numpy().squeeze()
    # PIL工具包
    # tensor中数据格式为c*h*w,正常数据格式为h*w*c,transpose()将0、1、2代表tensor中三个维度c、h、w,转换为1、2、0即h、w、c格式
    image = image.transpose(1, 2, 0)
    # 反标准化操作,均值u,标准差b,x = x*b+u
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    print(image)
    # clip()将数组中的元素值限制在给定的最小值和最大值之间,超出这个范围的值会被截断到最小值或最大值。
    image = image.clip(0, 1)
    print(image)
    return image


fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                 color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

验证结果展示:
随机一个batch的八张验证结果
随机一个batch的八张验证结果

  • 14
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值