【pytorch】关于OpenCV和PIL.Image读取图片的区别

数据使用方角度
首先从pytorch出发,torchvision.transforms()要求传入的图像是PIL.Image格式(通道要求是RGB格式的),另外模型处理输入要转换为[1,channel,H,W];

所以最终导入torchvision.transforms()的图像格式需要转成PIL.Image,且需要在转换后增加batch维度([channel,H,W]变成[1,channel,H,W])

# 制作dataset
class MyDataset(Dataset):
    def __init__(self,data_path,label_path,transforms=None):
        self.data_path = data_path
        self.label_path = label_path
        self.transform = transforms
        
        self.images = sorted(os.listdir(data_path))
         #这里的label_path就是csv文件
        self.label= pd.read_csv(self.label_path)['label']
        
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self,idx):
        img_path = os.path.join(self.data_path , self.images[idx])  # 将image_dir与图像列表中的每张图的名字连接成地址
        
        #这的标签是一个数字表示的类别,所以不用像图片一样进行操作。直接使用self.label来进行读取即可
         
#         image_data = cv2.imread(img_path)
        image_data = Image.open(img_path)
        
        if self.transform is not None :
            image_data = self.transform(image_data)
        
        label_data = self.label[idx]
      
        return image_data,label_data
# values from ImageNet, recommended by PyTorch
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]

trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=transform_mean, std=transform_std),
])
#划分数据集,将原来的训练集划分为  训练集和验证集
dataset = MyDataset(data_path,label_path,trans)
train_dataset,valid_dataset = random_split(dataset, [0.8,0.2])

#创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2)
valid_loader = DataLoader(valid_dataset,batch_size=4,shuffle=True,num_workers=2)
print(len(train_dataset))
print(len(valid_dataset))

上面这个代码 image_data = cv2.imread(img_path)会报错,原因是

  • torchvision.transforms.Resize预期接收的输入类型是PIL图像,但是它收到了一个NumPy数组。这导致了类型错误。

    具体来说,在你的数据集类中,你使用了OpenCV
    (cv2)库读取图像数据。OpenCV读取的图像数据是NumPy数组,而不是PIL图像。因此,直接将这些NumPy数组传递给torchvision.transforms.Resize时,就会导致类型不匹配的错误。

总结:

具体来说这两种方式读取图片没有大的区别,但是建议用image_data = Image.open(img_path)的方式来读。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

超好的小白

没体验过打赏,能让我体验一次吗

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值