小白学Pytorch使用(4-3):花数据集分类——自定义DataLoader

训练集和验证集数据路径

path = r’D:\咕泡人工智能-配套资料\配套资料\4.第四章 深度学习核⼼框架PyTorch\第六章:DataLoader自定义数据集制作\flower_data’
train_path = path + ‘/train_data’
val_path = path + ‘/valid_data’

训练集和验证集数据处理

train_dataset = FlowerDataset(root_dir=train_path, ann_file=train_file_path, transform=data_transforms[‘train’])
valid_dataset = FlowerDataset(root_dir=val_path, ann_file=val_file_path, transform=data_transforms[‘valid’])

DataLoader划分数据集

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)


### 三、数据集与标签划分验证


随机选取训练集一个batch中的一个数据,压缩1维度,转为numpy格式,进行反标准化操作,展示图像,输出图像对应标签



iter(train_loader)迭代train_loader数据,next()随机取一个batch数据

image, label = next(iter(train_loader))

print(image.shape)

torch.Size([64, 3, 64, 64])

print(label)

tensor([ 43, 72, 74, 87, 79, 27, 0, 59, 13, 63, 72, 68, 78, 87,

77, 72, 89, 31, 16, 82, 99, 82, 101, 50, 57, 69, 59, 79,

3, 50, 95, 73, 82, 2, 56, 70, 18, 87, 46, 73, 94, 90,

52, 75, 85, 98, 51, 36, 40, 97, 16, 86, 50, 57, 55, 80,

89, 28, 1, 57, 63, 16, 47, 1])

将维度为1的维度去除

sample = image[0].squeeze()

print(sample.shape)

torch.Size([3, 64, 64])

sample的维度为torch.Size([3, 64, 64]),转换为numpy格式[64, 64, 3]

sample = sample.permute((1, 2, 0)).numpy()

反标准化

sample = sample * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
sample = sample.clip(0, 1)
plt.imshow(sample)
plt.show()
print(‘label is :{}’.format(label[0].numpy()))


结果展示:  
 ![随机图片](https://img-blog.csdnimg.cn/direct/ad7cb0d569d84550a54e0d15db4cf271.png#pic_center)  
 ![随机图片对应标签](https://img-blog.csdnimg.cn/direct/40a93d0a4722487992668ec82c0ebd22.png#pic_center)


### 三、网络训练



dataloaders = {‘train’:train_loader, ‘valid’:val_loader}

打开cat_to_name.json文件,文件中有数字对应的实际类别名称

with open(‘D:/咕泡人工智能-配套资料\配套资料/4.第四章 深度学习核⼼框架PyTorch/第五章:图像识别模型与训练策略(重点)/cat_to_name.json’, ‘r’) as f:
cat_to_name = json.load(f)
# print(cat_to_name)
‘’’
{‘21’: ‘fire lily’, ‘3’: ‘canterbury bells’, ‘45’: ‘bolero deep blue’, ‘1’: ‘pink primrose’, ‘34’: ‘mexican aster’, ‘27’: ‘prince of wales feathers’,
‘7’: ‘moon orchid’, ‘16’: ‘globe-flower’, ‘25’: ‘grape hyacinth’, ‘26’: ‘corn poppy’, ‘79’: ‘toad lily’, ‘39’: ‘siam tulip’, ‘24’: ‘red ginger’,
‘67’: ‘spring crocus’, ‘35’: ‘alpine sea holly’, ‘32’: ‘garden phlox’, ‘10’: ‘globe thistle’, ‘6’: ‘tiger lily’, ‘93’: ‘ball moss’, ‘33’: ‘love in the mist’,
‘9’: ‘monkshood’, ‘102’: ‘blackberry lily’, ‘14’: ‘spear thistle’, ‘19’: ‘balloon flower’, ‘100’: ‘blanket flower’, ‘13’: ‘king protea’, ‘49’: ‘oxeye daisy’,
‘15’: ‘yellow iris’, ‘61’: ‘cautleya spicata’, ‘31’: ‘carnation’, ‘64’: ‘silverbush’, ‘68’: ‘bearded iris’, ‘63’: ‘black-eyed susan’, ‘69’: ‘windflower’,
‘62’: ‘japanese anemone’, ‘20’: ‘giant white arum lily’, ‘38’: ‘great masterwort’, ‘4’: ‘sweet pea’, ‘86’: ‘tree mallow’, ‘101’: ‘trumpet creeper’,
‘42’: ‘daffodil’, ‘22’: ‘pincushion flower’, ‘2’: ‘hard-leaved pocket orchid’, ‘54’: ‘sunflower’, ‘66’: ‘osteospermum’, ‘70’: ‘tree poppy’, ‘85’: ‘desert-rose’,
‘99’: ‘bromelia’, ‘87’: ‘magnolia’, ‘5’: ‘english marigold’, ‘92’: ‘bee balm’, ‘28’: ‘stemless gentian’, ‘97’: ‘mallow’, ‘57’: ‘gaura’, ‘40’: ‘lenten rose’,
‘47’: ‘marigold’, ‘59’: ‘orange dahlia’, ‘48’: ‘buttercup’, ‘55’: ‘pelargonium’, ‘36’: ‘ruby-lipped cattleya’, ‘91’: ‘hippeastrum’, ‘29’: ‘artichoke’, ‘71’: ‘gazania’,
‘90’: ‘canna lily’, ‘18’: ‘peruvian lily’, ‘98’: ‘mexican petunia’, ‘8’: ‘bird of paradise’, ‘30’: ‘sweet william’, ‘17’: ‘purple coneflower’, ‘52’: ‘wild pansy’,
‘84’: ‘columbine’, ‘12’: “colt’s foot”, ‘11’: ‘snapdragon’, ‘96’: ‘camellia’, ‘23’: ‘fritillary’, ‘50’: ‘common dandelion’, ‘44’: ‘poinsettia’, ‘53’: ‘primula’,
‘72’: ‘azalea’, ‘65’: ‘californian poppy’, ‘80’: ‘anthurium’, ‘76’: ‘morning glory’, ‘37’: ‘cape flower’, ‘56’: ‘bishop of llandaff’, ‘60’: ‘pink-yellow dahlia’,
‘82’: ‘clematis’, ‘58’: ‘geranium’, ‘75’: ‘thorn apple’, ‘41’: ‘barbeton daisy’, ‘95’: ‘bougainvillea’, ‘43’: ‘sword lily’, ‘83’: ‘hibiscus’, ‘78’: ‘lotus lotus’,
‘88’: ‘cyclamen’, ‘94’: ‘foxglove’, ‘81’: ‘frangipani’, ‘74’: ‘rose’, ‘89’: ‘watercress’, ‘73’: ‘water lily’, ‘46’: ‘wallflower’, ‘77’: ‘passion flower’,
‘51’: ‘petunia’}
‘’’

迁移学习:使用前人的网络结构和模型做训练

数据量较小时:对模型进行微小改动,如冻住某一部分不进行迭代更新训练,只训练更新较少的网络层

数据量中等时:对模型进行改动,如冻住少量部分不进行迭代更新训练,其他部分进行迭代更新训练

数据量较大时,整个模型不进行冻结,全部更新训练

加载models中提供的模型,并且直接用训练好的权重当作初始化参数

#可选网络结构比较多 [‘resnet’, ‘alexnet’, ‘vgg’, ‘squeezenet’, ‘densenet’, ‘inception’],resnet网络效果较好
model_name = ‘resnet’

是否用人家训练好的特征来做

此项目冻结输出层前面所有部分,不进行训练更新

feature_extract = True

是否用GPU训练——————固定写法

train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
print(‘CUDA is not available. Training on CPU …’)
else:
print(‘CUDA is available! Training on GPU …’)
device = torch.

  • 18
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值