PyTorch for Beginners (4-3): Flower Dataset Classification with a Custom DataLoader

Paths to the training and validation data

path = r'D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第六章:DataLoader自定义数据集制作\flower_data'
train_path = path + '/train_data'
val_path = path + '/valid_data'

Build the training and validation datasets

train_dataset = FlowerDataset(root_dir=train_path, ann_file=train_file_path, transform=data_transforms['train'])
valid_dataset = FlowerDataset(root_dir=val_path, ann_file=val_file_path, transform=data_transforms['valid'])
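
`FlowerDataset`, `train_file_path`, `val_file_path`, and `data_transforms` were all defined in the previous part of this series. For reference, here is a minimal sketch of what `data_transforms` might look like, consistent with the 64×64 tensors and the normalization constants used below; the augmentation choices are assumptions, not the original pipeline:

from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((64, 64)),                  # matches the [3, 64, 64] batch shape shown below
        transforms.RandomHorizontalFlip(),            # assumed augmentation
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet channel means
                             [0.229, 0.224, 0.225])   # ImageNet channel stds
    ]),
    'valid': transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
}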

Wrap the datasets in DataLoaders

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)  # shuffling the validation set is optional


### 3. Verifying the Dataset and Label Split


Take one sample from a random training batch, squeeze out any singleton dimensions, convert it to NumPy format, undo the normalization, display the image, and print its label.



iter(train_loader) creates an iterator over the loader; next() pulls one batch (random, since shuffle=True)

image, label = next(iter(train_loader))

print(image.shape)

torch.Size([64, 3, 64, 64])

print(label)

tensor([ 43,  72,  74,  87,  79,  27,   0,  59,  13,  63,  72,  68,  78,  87,
         77,  72,  89,  31,  16,  82,  99,  82, 101,  50,  57,  69,  59,  79,
          3,  50,  95,  73,  82,   2,  56,  70,  18,  87,  46,  73,  94,  90,
         52,  75,  85,  98,  51,  36,  40,  97,  16,  86,  50,  57,  55,  80,
         89,  28,   1,  57,  63,  16,  47,   1])

Squeeze out any dimensions of size 1 (a no-op here, since image[0] already has shape [3, 64, 64])

sample = image[0].squeeze()

print(sample.shape)

torch.Size([3, 64, 64])

sample has shape torch.Size([3, 64, 64]); permute it to HWC order [64, 64, 3] and convert to NumPy for matplotlib

sample = sample.permute((1, 2, 0)).numpy()

Undo the normalization: x * std + mean

sample = sample * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
sample = sample.clip(0, 1)
plt.imshow(sample)
plt.show()
print('label is :{}'.format(label[0].numpy()))


Output:  
 ![random image](https://img-blog.csdnimg.cn/direct/ad7cb0d569d84550a54e0d15db4cf271.png#pic_center)  
 ![label of the random image](https://img-blog.csdnimg.cn/direct/40a93d0a4722487992668ec82c0ebd22.png#pic_center)


### 4. Network Training



dataloaders = {'train': train_loader, 'valid': val_loader}

Open cat_to_name.json, which maps each numeric class ID to the actual flower name

with open('D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json', 'r') as f:
    cat_to_name = json.load(f)
# print(cat_to_name)
'''
{'21': 'fire lily', '3': 'canterbury bells', '45': 'bolero deep blue', '1': 'pink primrose', '34': 'mexican aster', '27': 'prince of wales feathers',
'7': 'moon orchid', '16': 'globe-flower', '25': 'grape hyacinth', '26': 'corn poppy', '79': 'toad lily', '39': 'siam tulip', '24': 'red ginger',
'67': 'spring crocus', '35': 'alpine sea holly', '32': 'garden phlox', '10': 'globe thistle', '6': 'tiger lily', '93': 'ball moss', '33': 'love in the mist',
'9': 'monkshood', '102': 'blackberry lily', '14': 'spear thistle', '19': 'balloon flower', '100': 'blanket flower', '13': 'king protea', '49': 'oxeye daisy',
'15': 'yellow iris', '61': 'cautleya spicata', '31': 'carnation', '64': 'silverbush', '68': 'bearded iris', '63': 'black-eyed susan', '69': 'windflower',
'62': 'japanese anemone', '20': 'giant white arum lily', '38': 'great masterwort', '4': 'sweet pea', '86': 'tree mallow', '101': 'trumpet creeper',
'42': 'daffodil', '22': 'pincushion flower', '2': 'hard-leaved pocket orchid', '54': 'sunflower', '66': 'osteospermum', '70': 'tree poppy', '85': 'desert-rose',
'99': 'bromelia', '87': 'magnolia', '5': 'english marigold', '92': 'bee balm', '28': 'stemless gentian', '97': 'mallow', '57': 'gaura', '40': 'lenten rose',
'47': 'marigold', '59': 'orange dahlia', '48': 'buttercup', '55': 'pelargonium', '36': 'ruby-lipped cattleya', '91': 'hippeastrum', '29': 'artichoke', '71': 'gazania',
'90': 'canna lily', '18': 'peruvian lily', '98': 'mexican petunia', '8': 'bird of paradise', '30': 'sweet william', '17': 'purple coneflower', '52': 'wild pansy',
'84': 'columbine', '12': "colt's foot", '11': 'snapdragon', '96': 'camellia', '23': 'fritillary', '50': 'common dandelion', '44': 'poinsettia', '53': 'primula',
'72': 'azalea', '65': 'californian poppy', '80': 'anthurium', '76': 'morning glory', '37': 'cape flower', '56': 'bishop of llandaff', '60': 'pink-yellow dahlia',
'82': 'clematis', '58': 'geranium', '75': 'thorn apple', '41': 'barbeton daisy', '95': 'bougainvillea', '43': 'sword lily', '83': 'hibiscus', '78': 'lotus lotus',
'88': 'cyclamen', '94': 'foxglove', '81': 'frangipani', '74': 'rose', '89': 'watercress', '73': 'water lily', '46': 'wallflower', '77': 'passion flower',
'51': 'petunia'}
'''

Transfer learning: reuse an existing network architecture and its pretrained weights.

Small dataset: change the model as little as possible, e.g. freeze most of it so it is not updated, and train only a few layers.

Medium dataset: change the model more, e.g. freeze a small part and train the rest (a sketch of this case follows below).

Large dataset: freeze nothing and fine-tune the entire model.
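
For the medium-data case, here is a minimal sketch of partial freezing. This is an illustration only; the project below instead freezes everything except the output layer:

from torchvision import models

model = models.resnet18(pretrained=True)
# Freeze everything first ...
for param in model.parameters():
    param.requires_grad = False
# ... then unfreeze the last residual stage and the classifier head.
for name, param in model.named_parameters():
    if name.startswith('layer4') or name.startswith('fc'):
        param.requires_grad = True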

Load a model provided by torchvision.models and use its pretrained weights as the initialization

# Many architectures are available: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']; resnet works well here
model_name = 'resnet'

Whether to reuse the pretrained features as-is.

This project freezes everything before the output layer, so those parts are not updated during training.

feature_extract = True

Train on GPU if available (boilerplate check)

train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU ...')
else:
    print('CUDA is available! Training on GPU ...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

We use the 18-layer ResNet because it trains faster; with more compute you could choose the 152-layer version instead

model_ft = models.resnet18()

The resnet18 architecture:

print(model_ft)

'''
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))

resnet18's output layer has 1000 classes; this project predicts 102 classes, so the fc layer is replaced later:

  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
'''

A helper that decides whether the model's parameters should be updated:

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        # Walk over all parameters of the network
        for param in model.parameters():
            # requires_grad controls whether a parameter is updated during
            # backprop; setting it to False freezes the parameter
            param.requires_grad = False

for param in model_ft.parameters():
    param.requires_grad = False

for name, param in model_ft.named_parameters():
    # Print every parameter name and its values
    print(name)
    print(param)

# e.g. the output-layer parameters:

'''
fc.weight
Parameter containing:
tensor([[ 0.0018, -0.0213, -0.0421, …, 0.0064, -0.0055, -0.0221],
[-0.0298, -0.0036, 0.0277, …, 0.0181, -0.0311, -0.0340],
[-0.0434, -0.0196, -0.0026, …, 0.0229, 0.0110, 0.0345],
…,
[-0.0097, 0.0290, 0.0391, …, -0.0091, -0.0028, -0.0373],
[ 0.0073, 0.0256, 0.0238, …, 0.0004, -0.0275, -0.0347],
[-0.0356, -0.0144, -0.0418, …, 0.0020, -0.0338, 0.0351]])
fc.bias
Parameter containing:
tensor([ 0.0226, 0.0271, -0.0001, -0.0043, -0.0339, -0.0216, 0.0056, -0.0067,
-0.0024, 0.0062, -0.0425, 0.0148, 0.0340, -0.0365, 0.0084, 0.0407,
... (1000 values in total; the rest of the printout is omitted here) ...
-0.0269, 0.0172, 0.0330, -0.0416, -0.0347, -0.0392, -0.0148, -0.0167])
'''
# Print the names of the parameters that can still be updated; since everything
# was set to requires_grad=False above, this check prints nothing here
# if param.requires_grad == True:
#     print('\t', name)

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Load the pretrained model.
    # ResNet comes in 18-, 50-, 101- and 152-layer variants.
    # models.resnet18() alone gives only the architecture; pretrained=True also
    # downloads the trained weights (cached under C:/Users/<user>/.cache).
    # Note: model_name is actually the torchvision.models module here.
    model_ft = model_name.resnet18(pretrained=use_pretrained)

    # Freeze all parameter updates in the network
    set_parameter_requires_grad(model_ft, feature_extract)

    # Input size of the fully connected layer: 512
    # resnet18's fc layer: (fc): Linear(in_features=512, out_features=1000, bias=True)
    num_ftrs = model_ft.fc.in_features
    # num_classes is our own number of classes; this replaces resnet18's fc layer,
    # and the new layer's parameters do get gradient updates
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    # # Set the input size to match your own configuration
    # input_size = 64
    # return model_ft, input_size
    return model_ft

model_ft = initialize_model(models, 102, feature_extract, use_pretrained=True)

# Run on GPU or CPU
model_ft = model_ft.to(device)
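
A quick sanity check (not in the original post) that the head was replaced and is the only trainable part:

print(model_ft.fc)  # expected: Linear(in_features=512, out_features=102, bias=True)
trainable = [name for name, param in model_ft.named_parameters() if param.requires_grad]
print(trainable)    # expected: ['fc.weight', 'fc.bias']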

Checkpoint paths (named however you like) for saving the network structure and weights

filename = 'D:/咕泡人工智能-配套资料/配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/best_my_false.pt'
filename_new = path + ‘/best_my_true.pt’

Decide which layers to train.

Start with all of model_ft's parameters in params_to_update.

params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        # In this project only the newly added fc layer has param.requires_grad == True
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t", name)

Optimizer setup: only the parameters in params_to_update are trained

optimizer_ft = optim.Adam(params_to_update, lr=1e-3)

Learning-rate decay: every step_size epochs, the LR is multiplied by gamma (illustrated below)

scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
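
As a standalone illustration of what this schedule does (using a toy parameter, not the real model), the LR stays at 1e-3 for epochs 0-9, drops to 1e-4 for epochs 10-19, and so on:

import torch
from torch import optim

p = torch.nn.Parameter(torch.zeros(1))           # toy parameter, for illustration only
opt = optim.Adam([p], lr=1e-3)
sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)
for epoch in range(30):
    if epoch in (0, 9, 10, 20):
        print(epoch, opt.param_groups[0]['lr'])  # 0.001, 0.001, 0.0001, 1e-05 (up to float rounding)
    opt.step()
    sched.step()                                 # advance the schedule once per epoch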

Loss function

criterion = nn.CrossEntropyLoss()
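
CrossEntropyLoss takes raw logits of shape (batch, 102) and integer class labels of shape (batch,), exactly like the label tensor printed earlier. For example:

logits = torch.randn(4, 102)              # stand-in for model outputs
targets = torch.tensor([43, 72, 74, 87])  # the first four labels from the batch above
print(criterion(logits, targets))         # a single scalar loss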

Load the weights saved by a previous training run

checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
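
The checkpoint at `filename` is assumed to have been written by an earlier training run in roughly this form; the 'state_dict' and 'best_acc' keys follow from the loading code above:

state = {
    'state_dict': model_ft.state_dict(),    # all weights
    'best_acc': best_acc,                   # best validation accuracy so far
    'optimizer': optimizer_ft.state_dict()  # optional extra (assumption)
}
torch.save(state, filename)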

Training function

def train_model(model, dataloaders, criterion, optimizer, filename_new, num_epochs=25):
    # Record the training start time
    since = time.time()
    # Track the best validation accuracy seen so far
    best_acc = 0
    # Put the model on the CPU or GPU
    model.to(device)

    # Bookkeeping for the losses and metrics printed during training
    # Validation accuracy
    val_acc_history = []
    # Training accuracy
    train_acc_history = []
    # Training loss
    train_losses = []
    # Validation loss
    valid_losses = []

    # Current learning rate; optimizer.param_groups is a list of dicts
    LRs = [optimizer.param_groups[0]['lr']]
    # print(optimizer.param_groups)
    '''

[{'params': [Parameter containing:
tensor([[-0.0065, -0.0276, 0.0032, …, -0.0383, -0.0091, 0.0323],
[ 0.0141, 0.0429, -0.0220, …, 0.0234, 0.0301, -0.0281],
[ 0.0064, -0.0427, 0.0039, …, 0.0214, -0.0171, -0.0016],
…,
[-0.0355, 0.0126, -0.0099, …, -0.0322, -0.0201, 0.0245],
[-0.0127, 0.0114, -0.0213, …, 0.0270, -0.0070, -0.0315],
[-0.0226, -0.0235, 0.0262, …, -0.0109, 0.0241, 0.0084]],
requires_grad=True), Parameter containing:
tensor([-0.0118, 0.0029, -0.0184, 0.0226, 0.0082, -0.0320, -0.0046, 0.0358,
-0.0234, 0.0430, 0.0245, 0.0431, -0.0127, -0.0231, 0.0230, 0.0357,
-0.0181, 0.0389, 0.0127, 0.0343, 0.0044, 0.0217, -0.0323, -0.0211,
0.0309, 0.0416, -0.0317, -0.0248, 0.0093, -0.0324, -0.0115, 0.0181,
-0.0190, -0.0005, 0.0418, -0.0369, -0.0144, -0.0229, -0.0295, -0.0048,
0.0088, -0.0371, -0.0203, -0.0163, 0.0073, 0.0044, -0.0410, -0.0289,
-0.0305, -0.0363, 0.0409, 0.0364, 0.0082, 0.0419, -0.0063, -0.0100,
0.0008, -0.0270, -0.0163, 0.0059, -0.0100, 0.0252, 0.0183, -0.0160,
0.0027, 0.0347, -0.0131, -0.0292, -0.0225, -0.0183, 0.0326, -0.0062,
-0.0422, -0.0220, -0.0410, -0.0408, 0.0405, -0.0046, -0.0339, 0.0411,
-0.0015, -0.0371, -0.0152, 0.0244, -0.0128, -0.0117, -0.0275, 0.0333,
0.0033, 0.0276, 0.0302, -0.0367, 0.0236, 0.0409, 0.0192, -0.0411,
-0.0224, -0.0200, -0.0321, 0.0120, -0.0427, 0.0333],
requires_grad=True)], 'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False,
'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.01}]
'''
# print(optimizer)
'''
Adam (
Parameter Group 0
amsgrad: False
betas: (0.9, 0.999)
capturable: False
differentiable: False
eps: 1e-08
foreach: None
fused: None
initial_lr: 0.01
lr: 0.01
maximize: False
weight_decay: 0
)
'''

    # Snapshot of the best model so far (updated later); initialize it with a deep
    # copy of the current weights (model.state_dict() holds the current parameters)
    best_model_wts = copy.deepcopy(model.state_dict())

    # Loop over the epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Training and validation phases
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            # Reset the running loss and the number of correct predictions
            running_loss = 0.0
            running_corrects = 0

            # Go through all the data; dataloaders is a dict, and phase selects which split to use
            for inputs, labels in dataloaders[phase]:
                # Move the batch to the CPU or GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                # Index of the largest of the 102 class scores
                _, preds = torch.max(outputs, 1)
                # Update the weights only in the training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()  # one iteration done

                # Accumulate the total loss and the number of correct predictions
                # inputs has shape (batch, c, h, w); inputs.size(0) is the batch dimension
                running_loss += loss.item() * inputs.size(0)
                # Count how often the top prediction matches the ground truth
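
The original post breaks off here. Below is a minimal sketch of how the loop typically concludes, consistent with the bookkeeping variables set up above; it assumes the module-level `scheduler` defined earlier and saves the best checkpoint to `filename_new`:

                running_corrects += torch.sum(preds == labels.data)

            # Per-epoch statistics for this phase
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Keep the best model seen on the validation set
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save({'state_dict': best_model_wts, 'best_acc': best_acc,
                            'optimizer': optimizer.state_dict()}, filename_new)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                LRs.append(optimizer.param_groups[0]['lr'])
            else:
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        scheduler.step()  # decay the LR once per epoch (module-level scheduler)
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load the best weights back and return the training history
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs

# Hypothetical invocation (names as defined above):
# model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = \
#     train_model(model_ft, dataloaders, criterion, optimizer_ft, filename_new, num_epochs=20)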
