图像识别必备模块ImageFolder使用解析与实战

一. 使用解析

1. ImagFolder导入模块 datasets&transforms

from torchvision import datasets,transforms

2. 读取对应文件夹下图片并将图片转为(C,H,W)的tensor格式(需要借助transform.ToTensor()

读取data_dir/x文件夹下对应的图片并经过data_transforms的转换

注: 在PyTorch中,图像的形状通常是 (C, H, W),但在matplotlib等图像处理库中,图像的形状通常是 (H, W, C)


#加上transforms
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ToTensor()
])
 
dataset=ImageFolder('./data/train',transform=transform)

3. dataset.classes可输出由图片标签的列表

二. 实战读取文件夹下训练与测试图片数据集

1. 导入模块

from torchvision import datasets,transforms

2. 设置好对应文件夹

data_dir = './flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

3. 设置好train与valid所需的data_transforms

data_transforms = {
    'train': 
        transforms.Compose([
        transforms.Resize([96, 96]),
        transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ]),
    'valid': 
        transforms.Compose([
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

 4. 使用ImageFolder导入对应图片数据集(datasets为字典形式)

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x), 
                     data_transforms[x]) for x in ['train', 'valid']}

5.  使用dataLoader作为迭代器,读取图片标签列表

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes
print(class_names)

 

  • 7
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值