PyTorch for Beginners (4-2): Classifying the Flower Dataset by Training All Layers

Task Background

Use the resnet18 architecture with pretrained weights to classify the 102-class flower dataset. Starting from the best model obtained through transfer learning in part 4-1, update all of resnet18's network parameters and save the best trained model as a .pt file.
(Compared with 4-1, changed text is shown in bold and changed code is marked with comments.)
The data used:
Dataset + model

1. Importing the Libraries

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
# torchvision's transforms module provides data augmentation and preprocessing; models provides pretrained models such as resnet; datasets loads folder-based datasets
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import copy
import json
from PIL import Image

2. Importing the Data

# Dataset root path (change this to your own path)
data_dir = r'D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

3. Data Preprocessing

# The dataset is fairly small, so apply data augmentation: rotation, cropping, horizontal/vertical flips, translation
data_transforms = {
    'train':
        # Compose() applies the following operations in order
        transforms.Compose([
        # Images come in different sizes; resize them to a common size, usually square (64, 128, 224, 256 are typical)
        # Smaller images lose more detail but train faster; 96 or 64 is common on a CPU
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),  # random rotation, angle drawn uniformly from -45 to 45 degrees
        transforms.CenterCrop(64),      # crop from the center; the cropped size is what the network actually sees, e.g. a 64*64 crop means a 64*64 input
        transforms.RandomHorizontalFlip(p=0.5), # random horizontal flip; each image flips with probability 0.5
        transforms.RandomVerticalFlip(p=0.5),   # random vertical flip; each image flips with probability 0.5
        # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),# (for extreme lighting conditions; rarely used) brightness, contrast, saturation, hue
        # transforms.RandomGrayscale(p=0.025),# (rarely used) convert to grayscale with some probability; with 3 channels this means R=G=B, i.e. RRR/GGG/BBB
        transforms.ToTensor(),      # convert the data to a tensor
        # Normalization. Our dataset is too small to give representative statistics of its own, so use the published large-dataset (ImageNet) values
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])# per-channel (r, g, b) mean u and std b: x' = (x - u) / b
    ]),

    # The validation set measures the model's real performance, so no augmentation
    'valid':
        transforms.Compose([
        # same size as the training input (the post-crop size)
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        # same mean and std as the training set
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
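Before wiring these transforms into a dataset, it can help to run the pipeline on a single image and confirm the output shape. A minimal sketch, not part of the original post; the file name below is a placeholder, so substitute any image from your own train folder:

# quick sanity check: the train pipeline should produce a 3 x 64 x 64 tensor
# ('1/image_06734.jpg' is a hypothetical file name; use any image you have)
sample = Image.open(train_dir + '/1/image_06734.jpg')
out = data_transforms['train'](sample)
print(out.shape)             # torch.Size([3, 64, 64])
print(out.min(), out.max())  # roughly within [-2.2, 2.7] after normalization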

# With small input images the batch size can be larger; mind your GPU memory
batch_size = 128

# Pair the train/valid folders with their preprocessing pipelines. image_datasets is a dict; datasets.ImageFolder treats each class's subfolder as one class (the dataset stores each class in its own folder)
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
# print(image_datasets)
'''
{'train': Dataset ImageFolder
    Number of datapoints: 6552
    Root location: D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data\train
    StandardTransform
Transform: Compose(
               Resize(size=[96, 96], interpolation=bilinear, max_size=None, antialias=True)
               RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
               CenterCrop(size=(64, 64))
               RandomHorizontalFlip(p=0.5)
               RandomVerticalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           ), 'valid': Dataset ImageFolder
    Number of datapoints: 818
    Root location: D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第五章:图像识别模型与训练策略(重点)\flower_data\valid
    StandardTransform
Transform: Compose(
               Resize(size=[64, 64], interpolation=bilinear, max_size=None, antialias=True)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )}
'''

# Shuffle the train/valid data and split it into batches. dataloaders is a dict
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
# print(dataloaders)
# {'train': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07130>, 'valid': <torch.utils.data.dataloader.DataLoader object at 0x000002425BA07100>}
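A quick way to confirm the loaders are set up correctly (a sketch, not in the original code) is to pull one batch and inspect its shapes:

# one batch: images are (batch, channels, height, width), labels are class indices
inputs, classes = next(iter(dataloaders['train']))
print(inputs.shape)   # torch.Size([128, 3, 64, 64])
print(classes.shape)  # torch.Size([128]), values in 0..101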

# Number of samples in the training and validation sets
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
# Get the class labels (the folder names)
class_names = image_datasets['train'].classes
# print(class_names)
'''
['1', '10', '100', '101', '102', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', 
    '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', 
    '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', 
    '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89',
    '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
'''
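Note that ImageFolder sorts the folder names as strings, so class index 0 corresponds to folder '1', index 1 to folder '10', and so on. A small sketch (not in the original post) that builds a reverse lookup table, useful when decoding predictions later:

# map a model output index back to its folder name, e.g. 0 -> '1', 1 -> '10'
idx_to_class = {v: k for k, v in image_datasets['train'].class_to_idx.items()}
# print(idx_to_class[0], idx_to_class[1])  # '1' '10'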

# Open cat_to_name.json, which maps the numeric labels to the actual class names (the label file)
with open('D:/咕泡人工智能-配套资料\配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
    # print(cat_to_name)
'''
{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers',
 '7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger', 
 '67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist', 
 '9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy', 
 '15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower',
 '62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper', 
 '42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 'sunflower', '66': 'osteospermum', '70': 'tree poppy', '85': 'desert-rose', 
 '99': 'bromelia', '87': 'magnolia', '5': 'english marigold', '92': 'bee balm', '28': 'stemless gentian', '97': 'mallow', '57': 'gaura', '40': 'lenten rose', 
 '47': 'marigold', '59': 'orange dahlia', '48': 'buttercup', '55': 'pelargonium', '36': 'ruby-lipped cattleya', '91': 'hippeastrum', '29': 'artichoke', '71': 'gazania', 
 '90': 'canna lily', '18': 'peruvian lily', '98': 'mexican petunia', '8': 'bird of paradise', '30': 'sweet william', '17': 'purple coneflower', '52': 'wild pansy', 
 '84': 'columbine', '12': "colt's foot", '11': 'snapdragon', '96': 'camellia', '23': 'fritillary', '50': 'common dandelion', '44': 'poinsettia', '53': 'primula', 
 '72': 'azalea', '65': 'californian poppy', '80': 'anthurium', '76': 'morning glory', '37': 'cape flower', '56': 'bishop of llandaff', '60': 'pink-yellow dahlia', 
 '82': 'clematis', '58': 'geranium', '75': 'thorn apple', '41': 'barbeton daisy', '95': 'bougainvillea', '43': 'sword lily', '83': 'hibiscus', '78': 'lotus lotus', 
 '88': 'cyclamen', '94': 'foxglove', '81': 'frangipani', '74': 'rose', '89': 'watercress', '73': 'water lily', '46': 'wallflower', '77': 'passion flower', 
 '51': 'petunia'}
'''

4. Loading the ResNet Network

# Transfer learning: reuse an existing architecture and its trained weights
# Small dataset: make only minor changes, e.g. freeze most of the network and update just a few layers
# Medium dataset: change more, e.g. freeze only a small part and train the rest (see the sketch below)
# Large dataset: freeze nothing and train the entire model
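As an illustration of the medium-data regime (a sketch only; it is not what this part does, since here every layer is trained), freezing everything in a resnet18 except layer4 and the fc head could look like this:

# illustrative sketch: freeze all of resnet18 except layer4 and the fc head
m = models.resnet18()
for name, param in m.named_parameters():
    param.requires_grad = name.startswith(('layer4', 'fc'))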

# Load a model from torchvision.models and use its trained weights as the initialization
# Many architectures are available ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']; resnet works well here
model_name = 'resnet'

# Whether to start from the pretrained features
# In 4-1 this flag froze everything before the output layer; in this part set_parameter_requires_grad (below) sets requires_grad=True, so all layers are trained
feature_extract = True

# Whether to train on the GPU (standard boilerplate)
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Use the 18-layer resnet; it is faster. With better hardware you could pick the 152-layer version
model_ft = models.resnet18()
# the resnet18 architecture
# print(model_ft)
'''
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  # resnet18's output layer has 1000 classes; this project predicts 102, so the fc layer is replaced later
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''

# Helper that decides whether the model's parameters should be updated
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        # iterate over all weight tensors in the 18-layer resnet
        for param in model.parameters():
            # requires_grad controls whether a parameter is updated during backprop; True means update. Changed from False (4-1) to True so every parameter is trained
            param.requires_grad = True
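With requires_grad=True everywhere, every parameter participates in training. A quick check (not in the original code):

# count trainable vs. total parameters; for resnet18 both are about 11.7M
n_trainable = sum(p.numel() for p in model_ft.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model_ft.parameters())
print(n_trainable, '/', n_total)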

# for param in model_ft.parameters():
#     param.requires_grad = False
# for name, param in model_ft.named_parameters():
#     # print every parameter's name and value
#     print(name)
#     print(param)
# The output layer's weight parameters look like this:
'''
fc.weight
Parameter containing:
tensor([[ 0.0018, -0.0213, -0.0421,  ...,  0.0064, -0.0055, -0.0221],
        [-0.0298, -0.0036,  0.0277,  ...,  0.0181, -0.0311, -0.0340],
        [-0.0434, -0.0196, -0.0026,  ...,  0.0229,  0.0110,  0.0345],
        ...,
        [-0.0097,  0.0290,  0.0391,  ..., -0.0091, -0.0028, -0.0373],
        [ 0.0073,  0.0256,  0.0238,  ...,  0.0004, -0.0275, -0.0347],
        [-0.0356, -0.0144, -0.0418,  ...,  0.0020, -0.0338,  0.0351]])
fc.bias
Parameter containing:
tensor([ 0.0226,  0.0271, -0.0001, -0.0043, -0.0339, -0.0216,  0.0056, -0.0067,
        -0.0024,  0.0062, -0.0425,  0.0148,  0.0340, -0.0365,  0.0084,  0.0407,
        -0.0335,  0.0133,  0.0155, -0.0303, -0.0304,  0.0291, -0.0440, -0.0278,
        -0.0324, -0.0041,  0.0068, -0.0349, -0.0380,  0.0169,  0.0203, -0.0439,
         0.0185, -0.0206,  0.0164, -0.0020,  0.0067, -0.0101, -0.0166,  0.0155,
        -0.0302,  0.0095, -0.0119,  0.0124,  0.0065,  0.0435, -0.0298, -0.0022,
        -0.0379,  0.0406, -0.0065, -0.0210,  0.0151,  0.0111, -0.0251, -0.0374,
        -0.0246, -0.0052, -0.0139, -0.0246,  0.0008,  0.0318,  0.0358,  0.0370,
         0.0277, -0.0124,  0.0391, -0.0158,  0.0052, -0.0437,  0.0076, -0.0352,
        -0.0045,  0.0270, -0.0332, -0.0177, -0.0420, -0.0116, -0.0224,  0.0410,
        -0.0063,  0.0408, -0.0062, -0.0034, -0.0084, -0.0306,  0.0077,  0.0284,
         0.0437, -0.0336, -0.0037, -0.0261, -0.0234, -0.0166,  0.0153, -0.0213,
        -0.0396,  0.0089,  0.0307,  0.0027, -0.0179, -0.0098,  0.0343,  0.0375,
        -0.0432,  0.0126, -0.0060,  0.0370, -0.0030,  0.0219,  0.0140, -0.0165,
         0.0409, -0.0134, -0.0352, -0.0265, -0.0312, -0.0205,  0.0268,  0.0429,
        -0.0260,  0.0252,  0.0016, -0.0357,  0.0052, -0.0040,  0.0173, -0.0442,
        -0.0089, -0.0116,  0.0166, -0.0381,  0.0143,  0.0032, -0.0010, -0.0092,
         0.0130, -0.0224, -0.0258,  0.0187,  0.0179,  0.0061,  0.0297,  0.0404,
         0.0218,  0.0116,  0.0345, -0.0209, -0.0256,  0.0091, -0.0302, -0.0306,
         0.0239,  0.0412,  0.0408,  0.0253, -0.0016, -0.0182, -0.0015,  0.0042,
         0.0150, -0.0024,  0.0138, -0.0164,  0.0076, -0.0181,  0.0339,  0.0179,
        -0.0289, -0.0327, -0.0399,  0.0419,  0.0386, -0.0345, -0.0321, -0.0413,
         0.0390, -0.0339,  0.0359,  0.0012,  0.0298,  0.0134, -0.0128, -0.0251,
         0.0435, -0.0231, -0.0432,  0.0385, -0.0423, -0.0137,  0.0170, -0.0412,
        -0.0413,  0.0114,  0.0428, -0.0425, -0.0089, -0.0290, -0.0112,  0.0277,
        -0.0109,  0.0083,  0.0324, -0.0163, -0.0389, -0.0206,  0.0052, -0.0091,
         0.0151,  0.0244,  0.0010, -0.0173, -0.0015, -0.0437, -0.0377, -0.0107,
        -0.0185,  0.0059, -0.0198, -0.0395, -0.0129, -0.0363,  0.0408,  0.0418,
        -0.0230, -0.0122, -0.0168,  0.0281,  0.0338,  0.0321,  0.0219, -0.0041,
         0.0307, -0.0244, -0.0007, -0.0307,  0.0387,  0.0051,  0.0417, -0.0241,
        -0.0165,  0.0186,  0.0210,  0.0373,  0.0192,  0.0415, -0.0230, -0.0091,
        -0.0324, -0.0416,  0.0304,  0.0065,  0.0398,  0.0036, -0.0232, -0.0392,
         0.0109, -0.0108, -0.0320, -0.0032, -0.0138,  0.0357, -0.0247,  0.0363,
        -0.0185, -0.0197, -0.0068, -0.0120, -0.0377, -0.0101,  0.0210, -0.0243,
         0.0269,  0.0128, -0.0142, -0.0385,  0.0185,  0.0233, -0.0051, -0.0172,
         0.0224, -0.0434,  0.0078,  0.0159, -0.0201, -0.0363, -0.0246,  0.0044,
        -0.0306, -0.0377, -0.0313,  0.0366,  0.0368, -0.0303,  0.0066,  0.0322,
        -0.0143,  0.0343,  0.0233, -0.0337, -0.0211, -0.0060, -0.0167, -0.0189,
        -0.0236,  0.0292,  0.0194,  0.0372, -0.0055,  0.0430,  0.0243, -0.0126,
         0.0208,  0.0273,  0.0145,  0.0269,  0.0020, -0.0070, -0.0102,  0.0016,
        -0.0191,  0.0397, -0.0001, -0.0044, -0.0360,  0.0095,  0.0357,  0.0089,
         0.0235, -0.0244,  0.0088,  0.0222,  0.0259,  0.0096,  0.0189,  0.0390,
         0.0401, -0.0208, -0.0358, -0.0197, -0.0248, -0.0088,  0.0085,  0.0018,
        -0.0132,  0.0289,  0.0287, -0.0400,  0.0063, -0.0054,  0.0399,  0.0123,
         0.0102, -0.0326,  0.0194, -0.0049,  0.0104, -0.0171, -0.0080,  0.0429,
        -0.0056, -0.0298,  0.0064,  0.0341, -0.0191,  0.0132, -0.0174,  0.0435,
         0.0035,  0.0093, -0.0321, -0.0366,  0.0307,  0.0088, -0.0395,  0.0357,
         0.0032, -0.0149, -0.0247,  0.0124,  0.0436,  0.0126, -0.0156,  0.0050,
         0.0109,  0.0183, -0.0404, -0.0018, -0.0104, -0.0395,  0.0390, -0.0306,
         0.0261, -0.0244, -0.0253, -0.0329,  0.0334,  0.0429, -0.0138, -0.0190,
        -0.0235, -0.0204, -0.0393,  0.0217, -0.0332, -0.0438, -0.0294, -0.0158,
        -0.0103,  0.0238, -0.0419, -0.0408,  0.0113,  0.0366,  0.0221, -0.0190,
        -0.0244,  0.0335,  0.0102, -0.0101, -0.0111, -0.0284, -0.0155, -0.0114,
         0.0137,  0.0019, -0.0006, -0.0074,  0.0137, -0.0260, -0.0037,  0.0199,
         0.0155, -0.0296,  0.0173,  0.0224,  0.0091, -0.0167,  0.0004,  0.0206,
        -0.0237, -0.0195,  0.0387, -0.0045,  0.0088,  0.0261, -0.0418, -0.0144,
        -0.0375, -0.0106, -0.0354,  0.0411, -0.0053, -0.0248, -0.0010,  0.0323,
        -0.0203,  0.0012,  0.0204, -0.0320,  0.0321,  0.0137,  0.0064, -0.0329,
         0.0051, -0.0340,  0.0171,  0.0422,  0.0266,  0.0238, -0.0164,  0.0103,
        -0.0413, -0.0355,  0.0127,  0.0207, -0.0240,  0.0398,  0.0323,  0.0217,
         0.0030,  0.0396,  0.0327, -0.0060,  0.0312, -0.0117, -0.0079,  0.0095,
         0.0423,  0.0010,  0.0018,  0.0233,  0.0434, -0.0210,  0.0049, -0.0072,
         0.0031,  0.0052, -0.0292, -0.0217, -0.0253,  0.0274, -0.0230, -0.0342,
        -0.0149,  0.0137, -0.0057,  0.0344, -0.0327,  0.0370,  0.0142, -0.0194,
         0.0109,  0.0366, -0.0046, -0.0203,  0.0088, -0.0117,  0.0263,  0.0020,
        -0.0335,  0.0387, -0.0196,  0.0386, -0.0433, -0.0012,  0.0138, -0.0383,
         0.0059, -0.0313, -0.0404, -0.0103, -0.0105, -0.0196, -0.0297,  0.0331,
         0.0372, -0.0275,  0.0007,  0.0266,  0.0006,  0.0255, -0.0035, -0.0365,
         0.0165, -0.0423,  0.0054,  0.0041,  0.0026, -0.0338,  0.0436, -0.0385,
         0.0346, -0.0295, -0.0315,  0.0096,  0.0376,  0.0166, -0.0351,  0.0077,
         0.0028,  0.0139,  0.0394,  0.0115, -0.0231, -0.0134,  0.0167,  0.0160,
        -0.0003, -0.0152,  0.0399,  0.0367, -0.0424,  0.0104, -0.0025, -0.0118,
         0.0020, -0.0115, -0.0253, -0.0290,  0.0042, -0.0025, -0.0155,  0.0101,
        -0.0170,  0.0135,  0.0283,  0.0441,  0.0294, -0.0083, -0.0428, -0.0267,
        -0.0247, -0.0344,  0.0177,  0.0173,  0.0029, -0.0369, -0.0150, -0.0193,
         0.0428, -0.0401,  0.0377,  0.0138, -0.0136, -0.0104,  0.0325,  0.0335,
        -0.0152, -0.0014, -0.0287,  0.0375, -0.0426,  0.0393,  0.0016, -0.0244,
         0.0394,  0.0212, -0.0019, -0.0024, -0.0212,  0.0249, -0.0244, -0.0144,
         0.0227,  0.0073,  0.0412, -0.0194, -0.0300,  0.0084, -0.0273,  0.0157,
        -0.0270, -0.0411,  0.0153, -0.0040, -0.0268,  0.0048,  0.0164, -0.0165,
        -0.0089,  0.0328,  0.0345, -0.0013, -0.0226, -0.0097, -0.0234,  0.0156,
         0.0236,  0.0076, -0.0349, -0.0442, -0.0293,  0.0132, -0.0420, -0.0020,
         0.0059, -0.0231, -0.0187, -0.0304,  0.0270,  0.0382,  0.0407, -0.0244,
         0.0394, -0.0204, -0.0356, -0.0287, -0.0221, -0.0330,  0.0249,  0.0247,
         0.0250, -0.0180, -0.0182, -0.0358, -0.0196, -0.0441,  0.0407, -0.0027,
         0.0177,  0.0167, -0.0258, -0.0097, -0.0210, -0.0106, -0.0331,  0.0008,
        -0.0392, -0.0378,  0.0418,  0.0093, -0.0442,  0.0419, -0.0233,  0.0076,
         0.0394,  0.0428,  0.0027, -0.0117, -0.0161,  0.0153,  0.0035,  0.0222,
         0.0375, -0.0300,  0.0200,  0.0285, -0.0120,  0.0074,  0.0054, -0.0117,
        -0.0161, -0.0088,  0.0428,  0.0103,  0.0136,  0.0376, -0.0172,  0.0086,
         0.0183,  0.0381,  0.0109, -0.0177, -0.0416, -0.0351,  0.0338, -0.0404,
         0.0324,  0.0041, -0.0172,  0.0002, -0.0415,  0.0423,  0.0172, -0.0068,
         0.0370,  0.0321,  0.0040,  0.0401,  0.0030,  0.0238,  0.0102, -0.0437,
         0.0204,  0.0432,  0.0399, -0.0186, -0.0067,  0.0205, -0.0098,  0.0009,
         0.0019,  0.0317, -0.0237,  0.0062,  0.0097, -0.0312, -0.0033,  0.0028,
        -0.0021,  0.0006, -0.0320, -0.0196,  0.0363,  0.0285,  0.0321, -0.0132,
         0.0344, -0.0232,  0.0379, -0.0166,  0.0311, -0.0174,  0.0431,  0.0006,
        -0.0066, -0.0344, -0.0076,  0.0245,  0.0286, -0.0388,  0.0114,  0.0204,
         0.0137,  0.0387,  0.0206, -0.0392, -0.0109,  0.0375,  0.0269,  0.0232,
        -0.0362,  0.0235, -0.0137,  0.0303,  0.0389, -0.0068,  0.0306,  0.0273,
         0.0264, -0.0074,  0.0315, -0.0291, -0.0027, -0.0061,  0.0188, -0.0123,
        -0.0360, -0.0266,  0.0292,  0.0248,  0.0127, -0.0251, -0.0426,  0.0066,
        -0.0005, -0.0162, -0.0236, -0.0330,  0.0339,  0.0319,  0.0135,  0.0260,
        -0.0389, -0.0375, -0.0192, -0.0079, -0.0066, -0.0261, -0.0441, -0.0042,
         0.0086,  0.0291,  0.0283,  0.0028, -0.0137, -0.0218,  0.0109, -0.0052,
        -0.0213, -0.0108, -0.0354, -0.0012,  0.0006, -0.0111, -0.0066,  0.0401,
        -0.0313,  0.0203,  0.0060,  0.0339,  0.0072,  0.0148,  0.0047,  0.0363,
        -0.0117,  0.0338,  0.0238, -0.0367,  0.0093,  0.0048,  0.0222,  0.0164,
         0.0399,  0.0005,  0.0056,  0.0023, -0.0330,  0.0241, -0.0403,  0.0047,
         0.0312, -0.0148,  0.0146,  0.0268,  0.0439, -0.0265, -0.0044,  0.0441,
         0.0395, -0.0164,  0.0117,  0.0343,  0.0355,  0.0243, -0.0091,  0.0083,
        -0.0213, -0.0196, -0.0082,  0.0070,  0.0403, -0.0407,  0.0270,  0.0375,
         0.0207,  0.0207, -0.0049,  0.0020,  0.0429, -0.0432, -0.0368, -0.0040,
         0.0035,  0.0114, -0.0345, -0.0286,  0.0232,  0.0342,  0.0437,  0.0193,
         0.0030,  0.0375,  0.0161,  0.0172, -0.0069, -0.0118, -0.0235, -0.0155,
        -0.0029, -0.0035,  0.0372, -0.0343,  0.0183, -0.0057, -0.0093, -0.0322,
        -0.0303,  0.0275, -0.0364, -0.0240, -0.0090, -0.0058, -0.0055,  0.0315,
        -0.0020,  0.0268, -0.0305, -0.0286, -0.0083,  0.0015, -0.0226,  0.0249,
        -0.0133, -0.0359,  0.0393,  0.0058, -0.0354,  0.0011,  0.0424,  0.0363,
         0.0405,  0.0006, -0.0422,  0.0363, -0.0298, -0.0319, -0.0131,  0.0021,
         0.0276, -0.0302, -0.0350, -0.0433,  0.0185,  0.0263,  0.0307, -0.0093,
         0.0377,  0.0031, -0.0115, -0.0297, -0.0327,  0.0103,  0.0179, -0.0071,
         0.0029, -0.0345, -0.0335, -0.0184,  0.0426,  0.0169,  0.0039, -0.0071,
         0.0421, -0.0185, -0.0235, -0.0288, -0.0305, -0.0199,  0.0091, -0.0162,
        -0.0269,  0.0172,  0.0330, -0.0416, -0.0347, -0.0392, -0.0148, -0.0167])
    '''
    # print the names of the parameters that will receive gradient updates
    # if param.requires_grad == True:
    #     print('\t', name)
    '''
	conv1.weight
	 bn1.weight
	 bn1.bias
	 layer1.0.conv1.weight
	 layer1.0.bn1.weight
	 layer1.0.bn1.bias
	 layer1.0.conv2.weight
	 layer1.0.bn2.weight
	 layer1.0.bn2.bias
	 layer1.1.conv1.weight
	 layer1.1.bn1.weight
	 layer1.1.bn1.bias
	 layer1.1.conv2.weight
	 layer1.1.bn2.weight
	 layer1.1.bn2.bias
	 layer2.0.conv1.weight
	 layer2.0.bn1.weight
	 layer2.0.bn1.bias
	 layer2.0.conv2.weight
	 layer2.0.bn2.weight
	 layer2.0.bn2.bias
	 layer2.0.downsample.0.weight
	 layer2.0.downsample.1.weight
	 layer2.0.downsample.1.bias
	 layer2.1.conv1.weight
	 layer2.1.bn1.weight
	 layer2.1.bn1.bias
	 layer2.1.conv2.weight
	 layer2.1.bn2.weight
	 layer2.1.bn2.bias
	 layer3.0.conv1.weight
	 layer3.0.bn1.weight
	 layer3.0.bn1.bias
	 layer3.0.conv2.weight
	 layer3.0.bn2.weight
	 layer3.0.bn2.bias
	 layer3.0.downsample.0.weight
	 layer3.0.downsample.1.weight
	 layer3.0.downsample.1.bias
	 layer3.1.conv1.weight
	 layer3.1.bn1.weight
	 layer3.1.bn1.bias
	 layer3.1.conv2.weight
	 layer3.1.bn2.weight
	 layer3.1.bn2.bias
	 layer4.0.conv1.weight
	 layer4.0.bn1.weight
	 layer4.0.bn1.bias
	 layer4.0.conv2.weight
	 layer4.0.bn2.weight
	 layer4.0.bn2.bias
	 layer4.0.downsample.0.weight
	 layer4.0.downsample.1.weight
	 layer4.0.downsample.1.bias
	 layer4.1.conv1.weight
	 layer4.1.bn1.weight
	 layer4.1.bn1.bias
	 layer4.1.conv2.weight
	 layer4.1.bn2.weight
	 layer4.1.bn2.bias
	 fc.weight
	 fc.bias
	'''

5. Model Training and Parameter Configuration

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Load the pretrained model
    # resnet comes in 18-, 50-, 101-, 152-layer variants, among others
    # model_name here is the torchvision.models module; model_name.resnet18() alone gives only the architecture, no pretrained weights. pretrained controls whether trained weights are loaded; pretrained=True downloads resnet18 to the local user cache
    model_ft = model_name.resnet18(pretrained=use_pretrained)

    # in 4-1 this froze every parameter; in this part it sets requires_grad=True everywhere, so nothing is frozen
    set_parameter_requires_grad(model_ft, feature_extract)

    # the fully connected layer's input size: 512
    # resnet18's fc layer: (fc): Linear(in_features=512, out_features=1000, bias=True)
    num_ftrs = model_ft.fc.in_features
    # num_classes is our own task's class count; this replaces resnet18's fc output layer, and the new layer's parameters can receive gradient updates
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    # # the input size depends on your own configuration
    # input_size = 64
    # return model_ft, input_size
    return model_ft

# input_size turned out to be unused, so it was removed
# model_ft, input_size = initialize_model(models, 102, feature_extract, use_pretrained=True)
model_ft = initialize_model(models, 102, feature_extract, use_pretrained=True)
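Note: newer torchvision releases (0.13 and later) deprecate the pretrained flag in favor of a weights argument. If you see a deprecation warning, the equivalent call should be the following (verify against your installed torchvision version):

# model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)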

# compute on the GPU or the CPU
model_ft = model_ft.to(device)

# Model save paths (pick your own names): architecture plus weights
# filename: path of the best model from the transfer-learning stage
filename = 'D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/best_my.pt'
# filename_new: path of the best model after training all layers
filename_new = 'D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/best_my_true.pt'

# Whether to train all layers
# collect model_ft's trainable parameters into params_to_update
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        # in 4-1 only the newly added fc layer had param.requires_grad == True; here every parameter does, so all of them are collected and updated
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)

# Optimizer: only the parameters in params_to_update are trained
# optimizer_ft = optim.Adam(params_to_update, lr=1e-2)
# We resume from the transfer-learned weights, so use a smaller learning rate
optimizer_ft = optim.Adam(params_to_update, lr=1e-3)
# Learning-rate decay: every step_size epochs the learning rate is multiplied by gamma
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
# Loss function
criterion = nn.CrossEntropyLoss()
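To see what the StepLR schedule does before committing to a long run, here is a standalone sketch (not in the original post) using a throwaway optimizer with the same settings:

# the learning rate is multiplied by gamma=0.1 after every 10 scheduler steps
demo_opt = optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
demo_sched = optim.lr_scheduler.StepLR(demo_opt, step_size=10, gamma=0.1)
lrs = []
for _ in range(20):
    demo_opt.step()   # step the optimizer first to avoid a scheduler warning
    demo_sched.step()
    lrs.append(demo_sched.get_last_lr()[0])
print(lrs)  # starts at 1e-3, drops to 1e-4 after 10 steps, then toward 1e-5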

# Load the weights trained in the transfer-learning stage
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
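If the checkpoint was saved on a GPU machine and you are resuming on a CPU-only one (or the reverse), torch.load accepts a map_location argument; a one-line variant of the load above:

# checkpoint = torch.load(filename, map_location=device)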

# Training function
# filename is replaced by filename_new here, and the same change is made inside the function
def train_model(model, dataloaders, criterion, optimizer, filename_new, num_epochs=25):
    # record the start time
    since = time.time()
    # best validation accuracy seen so far
    best_acc = 0
    # put the model on the CPU or GPU
    model.to(device)

    # metrics printed during training
    # validation accuracy
    val_acc_history = []
    # training accuracy
    train_acc_history = []
    # training loss
    train_losses = []
    # validation loss
    valid_losses = []

    # current learning rate; optimizer.param_groups is a list of dicts
    LRs = [optimizer.param_groups[0]['lr']]
    # print(optimizer.param_groups)
    '''
    [{'params': [Parameter containing:
tensor([[-0.0065, -0.0276,  0.0032,  ..., -0.0383, -0.0091,  0.0323],
        [ 0.0141,  0.0429, -0.0220,  ...,  0.0234,  0.0301, -0.0281],
        [ 0.0064, -0.0427,  0.0039,  ...,  0.0214, -0.0171, -0.0016],
        ...,
        [-0.0355,  0.0126, -0.0099,  ..., -0.0322, -0.0201,  0.0245],
        [-0.0127,  0.0114, -0.0213,  ...,  0.0270, -0.0070, -0.0315],
        [-0.0226, -0.0235,  0.0262,  ..., -0.0109,  0.0241,  0.0084]],
       requires_grad=True), Parameter containing:
tensor([-0.0118,  0.0029, -0.0184,  0.0226,  0.0082, -0.0320, -0.0046,  0.0358,
        -0.0234,  0.0430,  0.0245,  0.0431, -0.0127, -0.0231,  0.0230,  0.0357,
        -0.0181,  0.0389,  0.0127,  0.0343,  0.0044,  0.0217, -0.0323, -0.0211,
         0.0309,  0.0416, -0.0317, -0.0248,  0.0093, -0.0324, -0.0115,  0.0181,
        -0.0190, -0.0005,  0.0418, -0.0369, -0.0144, -0.0229, -0.0295, -0.0048,
         0.0088, -0.0371, -0.0203, -0.0163,  0.0073,  0.0044, -0.0410, -0.0289,
        -0.0305, -0.0363,  0.0409,  0.0364,  0.0082,  0.0419, -0.0063, -0.0100,
         0.0008, -0.0270, -0.0163,  0.0059, -0.0100,  0.0252,  0.0183, -0.0160,
         0.0027,  0.0347, -0.0131, -0.0292, -0.0225, -0.0183,  0.0326, -0.0062,
        -0.0422, -0.0220, -0.0410, -0.0408,  0.0405, -0.0046, -0.0339,  0.0411,
        -0.0015, -0.0371, -0.0152,  0.0244, -0.0128, -0.0117, -0.0275,  0.0333,
         0.0033,  0.0276,  0.0302, -0.0367,  0.0236,  0.0409,  0.0192, -0.0411,
        -0.0224, -0.0200, -0.0321,  0.0120, -0.0427,  0.0333],
       requires_grad=True)], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 
       'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.01}]
    '''
    # inspect the optimizer's settings
    # print(optimizer)
    '''
    Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        initial_lr: 0.01
        lr: 0.01
        maximize: False
        weight_decay: 0
    )
    '''

    # Best model so far; it will change during training, so initialize it with a deep copy of the current weights (model.state_dict() holds the model's current parameters)
    best_model_wts = copy.deepcopy(model.state_dict())

    # iterate epoch by epoch
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()  # evaluation mode

            # reset the running loss and correct-prediction count
            running_loss = 0.0
            running_corrects = 0

            # iterate over every batch; dataloaders is a dict, and phase selects which split to read
            for inputs, labels in dataloaders[phase]:
                # to(device) moves the batch to your CPU or GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the gradients
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                # index of the largest score among the 102 classes
                _, preds = torch.max(outputs, 1)
                # update weights only in the training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step() # one optimization step

                # accumulate the total loss and the correct count
                # inputs has shape (batch, c, h, w); inputs.size(0) is the batch dimension
                running_loss += loss.item() * inputs.size(0)
                # count predictions that match the ground truth
                running_corrects += torch.sum(preds == labels.data)

            # per-epoch loss and accuracy
            epoch_loss = running_loss / len(dataloaders[phase].dataset)  # average over the dataset
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            # how long one epoch takes
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # keep the best model seen so far
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),  # keys are the layer names, values are the trained weights
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state, filename_new)

            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                # scheduler.step(epoch_loss)  # learning-rate decay (unused here)
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()
        scheduler.step()  # learning-rate decay

    # total running time
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # after training, load the best epoch's weights as the final result, ready for testing shortly
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

# Train the model
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, filename_new, num_epochs=20)

Training progress (screenshots in the original post):
Accuracy and loss at the start of training
Accuracy and loss at the end of training
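The returned histories make it easy to plot the curves behind those screenshots. A sketch (not in the original post); the accuracies are 0-dim tensors, so move them to the CPU first:

# plot training/validation accuracy and loss per epoch
train_acc = [a.cpu().item() for a in train_acc_history]
val_acc = [a.cpu().item() for a in val_acc_history]
plt.figure()
plt.plot(train_acc, label='train acc')
plt.plot(val_acc, label='valid acc')
plt.plot(train_losses, label='train loss')
plt.plot(valid_losses, label='valid loss')
plt.xlabel('epoch')
plt.legend()
plt.show()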

6. Validating the Best Trained Model

# Load the best trained model
# note: filename changed to filename_new
checkpoint = torch.load(filename_new)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

# Take one random batch of validation data for testing
dataiter = iter(dataloaders['valid'])
images, labels = next(dataiter)
model_ft.eval()
if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

# take the class with the highest score
_, preds_tensor = torch.max(output, 1)
# if the tensor is on the GPU, move it to the CPU first and then convert to ndarray; otherwise convert directly
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())


# Display the prediction results
def im_convert(tensor):
    """Convert a tensor back into a displayable image."""
    # clone the data onto the CPU
    image = tensor.to("cpu").clone().detach()
    # convert from tensor to ndarray
    # numpy.squeeze(a, axis=None) removes single-dimension entries from an array's shape; a is the input array and axis selects the dimension to drop. With axis=None, all singleton dimensions are removed.
    image = image.numpy().squeeze()
    # tensors store images as c*h*w while images are normally h*w*c; transpose(1, 2, 0) reorders dimensions c, h, w into h, w, c
    image = image.transpose(1, 2, 0)
    # undo the normalization: with mean u and std b, x = x*b + u
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    # clip() limits array values to the given range; anything outside is clamped to the min or max
    image = image.clip(0, 1)
    return image


fig = plt.figure(figsize=(20, 20))
columns = 4
rows = 2

for idx in range(columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    plt.imshow(im_convert(images[idx]))
    # preds and labels are indices into class_names (ImageFolder's sorted folder names),
    # so map index -> folder name -> flower name before looking up cat_to_name
    ax.set_title("{} ({})".format(cat_to_name[class_names[preds[idx]]], cat_to_name[class_names[labels[idx].item()]]),
                 color=("green" if preds[idx] == labels[idx].item() else "red"))
plt.show()

Validation results (screenshot in the original post):
Eight predictions from one random validation batch
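One batch only eyeballs eight images; a sketch (not in the original post) for measuring accuracy over the whole validation set with the loaded best weights:

# overall validation accuracy
correct = 0
with torch.no_grad():
    for inputs, labels in dataloaders['valid']:
        outputs = model_ft(inputs.to(device))
        correct += (outputs.argmax(1).cpu() == labels).sum().item()
print('valid accuracy: {:.4f}'.format(correct / dataset_sizes['valid']))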
