训练集和验证集数据路径
path = r’D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第六章:DataLoader自定义数据集制作\flower_data’
train_path = path + ‘/train_data’
val_path = path + ‘/valid_data’
训练集和验证集数据处理
train_dataset = FlowerDataset(root_dir=train_path, ann_file=train_file_path, transform=data_transforms[‘train’])
valid_dataset = FlowerDataset(root_dir=val_path, ann_file=val_file_path, transform=data_transforms[‘valid’])
DataLoader划分数据集
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)
### 三、数据集与标签划分验证
随机选取训练集一个batch中的一个数据,压缩1维度,转为numpy格式,进行反标准化操作,展示图像,输出图像对应标签
iter(train_loader)迭代train_loader数据,next()随机取一个batch数据
image, label = next(iter(train_loader))
print(image.shape)
torch.Size([64, 3, 64, 64])
print(label)
tensor([ 43, 72, 74, 87, 79, 27, 0, 59, 13, 63, 72, 68, 78, 87,
77, 72, 89, 31, 16, 82, 99, 82, 101, 50, 57, 69, 59, 79,
3, 50, 95, 73, 82, 2, 56, 70, 18, 87, 46, 73, 94, 90,
52, 75, 85, 98, 51, 36, 40, 97, 16, 86, 50, 57, 55, 80,
89, 28, 1, 57, 63, 16, 47, 1])
将维度为1的维度去除
sample = image[0].squeeze()
print(sample.shape)
torch.Size([3, 64, 64])
sample的维度为torch.Size([3, 64, 64]),转换为numpy格式[64, 64, 3]
sample = sample.permute((1, 2, 0)).numpy()
反标准化
sample = sample * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
sample = sample.clip(0, 1)
plt.imshow(sample)
plt.show()
print(‘label is :{}’.format(label[0].numpy()))
结果展示:
![随机图片](https://img-blog.csdnimg.cn/direct/ad7cb0d569d84550a54e0d15db4cf271.png#pic_center)
![随机图片对应标签](https://img-blog.csdnimg.cn/direct/40a93d0a4722487992668ec82c0ebd22.png#pic_center)
### 三、网络训练
dataloaders = {‘train’:train_loader, ‘valid’:val_loader}
打开cat_to_name.json文件,文件中有数字对应的实际类别名称
with open(‘D:/咕泡人工智能-配套资料\配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json’, ‘r’) as f:
cat_to_name = json.load(f)
# print(cat_to_name)
‘’’
{‘21’: ‘fire lily’, ‘3’: ‘canterbury bells’, ‘45’: ‘bolero deep blue’, ‘1’: ‘pink primrose’, ‘34’: ‘mexican aster’, ‘27’: ‘prince of wales feathers’,
‘7’: ‘moon orchid’, ‘16’: ‘globe-flower’, ‘25’: ‘grape hyacinth’, ‘26’: ‘corn poppy’, ‘79’: ‘toad lily’, ‘39’: ‘siam tulip’, ‘24’: ‘red ginger’,
‘67’: ‘spring crocus’, ‘35’: ‘alpine sea holly’, ‘32’: ‘garden phlox’, ‘10’: ‘globe thistle’, ‘6’: ‘tiger lily’, ‘93’: ‘ball moss’, ‘33’: ‘love in the mist’,
‘9’: ‘monkshood’, ‘102’: ‘blackberry lily’, ‘14’: ‘spear thistle’, ‘19’: ‘balloon flower’, ‘100’: ‘blanket flower’, ‘13’: ‘king protea’, ‘49’: ‘oxeye daisy’,
‘15’: ‘yellow iris’, ‘61’: ‘cautleya spicata’, ‘31’: ‘carnation’, ‘64’: ‘silverbush’, ‘68’: ‘bearded iris’, ‘63’: ‘black-eyed susan’, ‘69’: ‘windflower’,
‘62’: ‘japanese anemone’, ‘20’: ‘giant white arum lily’, ‘38’: ‘great masterwort’, ‘4’: ‘sweet pea’, ‘86’: ‘tree mallow’, ‘101’: ‘trumpet creeper’,
‘42’: ‘daffodil’, ‘22’: ‘pincushion flower’, ‘2’: ‘hard-leaved pocket orchid’, ‘54’: ‘sunflower’, ‘66’: ‘osteospermum’, ‘70’: ‘tree poppy’, ‘85’: ‘desert-rose’,
‘99’: ‘bromelia’, ‘87’: ‘magnolia’, ‘5’: ‘english marigold’, ‘92’: ‘bee balm’, ‘28’: ‘stemless gentian’, ‘97’: ‘mallow’, ‘57’: ‘gaura’, ‘40’: ‘lenten rose’,
‘47’: ‘marigold’, ‘59’: ‘orange dahlia’, ‘48’: ‘buttercup’, ‘55’: ‘pelargonium’, ‘36’: ‘ruby-lipped cattleya’, ‘91’: ‘hippeastrum’, ‘29’: ‘artichoke’, ‘71’: ‘gazania’,
‘90’: ‘canna lily’, ‘18’: ‘peruvian lily’, ‘98’: ‘mexican petunia’, ‘8’: ‘bird of paradise’, ‘30’: ‘sweet william’, ‘17’: ‘purple coneflower’, ‘52’: ‘wild pansy’,
‘84’: ‘columbine’, ‘12’: “colt’s foot”, ‘11’: ‘snapdragon’, ‘96’: ‘camellia’, ‘23’: ‘fritillary’, ‘50’: ‘common dandelion’, ‘44’: ‘poinsettia’, ‘53’: ‘primula’,
‘72’: ‘azalea’, ‘65’: ‘californian poppy’, ‘80’: ‘anthurium’, ‘76’: ‘morning glory’, ‘37’: ‘cape flower’, ‘56’: ‘bishop of llandaff’, ‘60’: ‘pink-yellow dahlia’,
‘82’: ‘clematis’, ‘58’: ‘geranium’, ‘75’: ‘thorn apple’, ‘41’: ‘barbeton daisy’, ‘95’: ‘bougainvillea’, ‘43’: ‘sword lily’, ‘83’: ‘hibiscus’, ‘78’: ‘lotus lotus’,
‘88’: ‘cyclamen’, ‘94’: ‘foxglove’, ‘81’: ‘frangipani’, ‘74’: ‘rose’, ‘89’: ‘watercress’, ‘73’: ‘water lily’, ‘46’: ‘wallflower’, ‘77’: ‘passion flower’,
‘51’: ‘petunia’}
‘’’
迁移学习:使用前人的网络结构和模型做训练
数据量较小时:对模型进行微小改动,如冻住某一部分不进行迭代更新训练,只训练更新较少的网络层
数据量中等时:对模型进行改动,如冻住少量部分不进行迭代更新训练,其他部分进行迭代更新训练
数据量较大时,整个模型不进行冻结,全部更新训练
加载models中提供的模型,并且直接用训练好的权重当作初始化参数
#可选网络结构比较多 [‘resnet’, ‘alexnet’, ‘vgg’, ‘squeezenet’, ‘densenet’, ‘inception’],resnet网络效果较好
model_name = ‘resnet’
是否用人家训练好的特征来做
此项目冻结输出层前面所有部分,不进行训练更新
feature_extract = True
是否用GPU训练——————固定写法
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
print(‘CUDA is not available. Training on CPU …’)
else:
print(‘CUDA is available! Training on GPU …’)
device = torch.