Pytorch框架下的语义分割实战(一,数据集处理)

认真学习了这位博主ZJE_ANDY (下文称Z博,如有冒犯,请原谅)的语义分割项目,感谢感谢!!

pytorch用FCN语义分割手提包数据集(训练+预测单张输入图片代码)

分享一下笔记,超详细哦!

首先来看一下dataset.py

Z博整理的数据集有训练集原图(放在了last文件夹下)和训练集标签图(放在last_mask文件夹下),数据集的前期整理代码文件名为BagData.py,后期只需要改一下文件目录就可以啦,多方便呢。。。

将代码附在这里,添加了些注释。

'''
BagData.py
'''
import os
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import numpy as np
import cv2


#transform是对图像进行预处理、数据增强等。Compose将多个处理步骤整合到一起。
#ToTensor:将原始取值0-255像素值,归一化为0-1
#Normalize:用像素值的均值和标准偏差对像素值进行标准化
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

def onehot(data, n):
    buf = np.zeros(data.shape + (n, ))
    nmsk = np.arange(data.size)*n + data.ravel()
    buf.ravel()[nmsk-1] = 1
    return buf

class BagDataset(Dataset):
    def __init__(self, transform=None):
        self.transform = transform
        
    def __len__(self):
        return len(os.listdir('./bags/last'))

    def __getitem__(self, idx):
        #读取原图
        img_name = os.listdir('./bags/last')[idx]
        imgA = cv2.imread('./bags/last/'+img_name)
        imgA = cv2.resize(imgA, (160, 160))
        
        #读取标签图,即二值图
        imgB = cv2.imread('/bags/bags/last_msk/'+img_name, 0)
        imgB = cv2.resize(imgB, (160, 160))

        imgB = imgB/255
        imgB = imgB.astype('uint8')
        imgB = onehot(imgB, 2) #因为此代码是二分类问题,即分割出手提包和背景两样就行,因此这里参数是2
        imgB = imgB.transpose(2,0,1) #imgB不经过transform处理,所以要手动把(H,W,C)转成(C,H,W)
        imgB = torch.FloatTensor(imgB)
        if self.transform:
            imgA = self.transform(imgA) #一转成向量后,imgA通道就变成(C,H,W)
        return imgA, imgB

bag = BagDataset(transform)
train_size = int(0.9 * len(bag))    #整个训练集中,90%为训练集
test_size = len(bag) - train_size

train_dataset, test_dataset = random_split(bag, [train_size, test_size]) #按照上述比例(9:1)划分训练集和测试集
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4)

if __name__ =='__main__':
    for train_batch in train_dataloader:
        print(train_batch)

    for test_batch in test_dataloader:
        print(test_batch)

下面按照代码顺序,讲解某些语句的含义和作用,如有不当,欢迎指出丫。。。

①transform

torchvision中的transform是对图像进行预处理、数据增强等。
Compose将多个处理步骤整合到一起。 ToTensor:将原始取值0-255像素值,归一化为0-1。 Normalize:用像素值的均值和标准偏差对像素值进行标准化。

②One-Hot编码,又称一位有效编码。

主要采用N位寄存器对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候只有一位有效。此编码

  • 11
    点赞
  • 70
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值