Fast AI1.0 自定义数据集

  1. 加载必要的包
    fastai.vision 包含处理计算机视觉问题常用的类与方法:比如图像类image, 数据增强类transform, 通过预训练网络来快速搭建模型的create_cnn函数等等。
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
  1. 导入数据集摘要信息
    数据集的摘要信息存储再以csv结尾的无格式文件种,第一列Image表示图片的名称,第二列Id表示图片的标签,如下所示。
path = Path('../data')
trn_imgs = pd.read_csv(path/'train.csv'); 
trn_imgs.head()

training set abstract

  1. 定义TripleImageItem
    我们希望定义一个Item,其包含3张图像,其中第一张图像通常称为anchor, 第二张图像与第一张图像为同一类别,第三张图像与第一张图像为不同类别。
    自定义Item时,最重要的是obj和data属性,obj属性描述这个item,并且依靠obj可以复制item;data是实际用于学习器学习的数据。也就是说对obj进行简单的预处理,然后将其赋值给data.
class TripleImageItem(ItemBase):
    def __init__(self, img1, img2, img3, mean, std): 
        """
            img1, img2, img3 类型为Image, 且shape相同
        """
        self.img1, self.img2, self.img3 = img1, img2, img3
        self.mean, self.std = mean, std
        self.obj = (img1, img2, img3)
        self.data = self.normalize(img1.data, img2.data, img3.data)
        
    def apply_tfms(self, tfms,*args, **kwargs):
        self.img1 = self.img1.apply_tfms(tfms, *args, **kwargs)
        self.img2 = self.img2.apply_tfms(tfms, *args, **kwargs)
        self.img3 = self.img3.apply_tfms(tfms, *args, **kwargs)
        self.data = self.normalize(self.img1.data, self.img2.data, self.img3.data)
        return self
        
    def normalize(self, img1, img2, img3):
        img1 = (img1-self.mean)/self.std
        img2 = (img2-self.mean)/self.std
        img3 = (img3-self.mean)/self.std
        return img1, img2, img3
        
    def __repr__(self): return f'{self.__class__.__name__} {self.img1.shape, self.img2.shape, self.img3.shape}'
    
    def to_one(self): 
        tmp1 = self.data[0]*self.std+self.mean
        tmp2 = self.data[1]*self.std+self.mean
        tmp3 = self.data[2]*self.std+self.mean
        return Image(torch.cat([tmp1, tmp2, tmp3],2))
        
    def show(self): return self.to_one().show()
  1. 定义TripleItemList
    TripleItemList是一个项表,在这里我们主要关注get方法,该方法描述如何生成一个item的必要信息,并利用前面定义的TripleItem将其包装为一个项。
class TripleImageItemList(ImageItemList):
    def __init__(self, items, **kwargs):
        self.kwargs = kwargs
        super().__init__(items, **kwargs)
        
        self.mean = torch.from_numpy(np.array([0.464589, 0.493298, 0.527568])).reshape((-1,1,1)).float()
        self.std = torch.from_numpy(np.array([0.287433, 0.282659, 0.285415])).reshape((-1,1,1)).float()
        
    def __len__(self):	return len(self.items)
    
    def __getitem__(self,idxs:int):
        if isinstance(try_int(idxs), int): return self.get(idxs)
        else: return self.new(items=self.items[idxs], xtra=self.xtra.iloc[idxs].copy())
        
    def new(self, items, **kwargs): return TripleImageItemList(items, **kwargs)
    
    def get(self, i):
        cur_label = self.xtra['Id'].iloc[i]
        img = super().get(i)
        
        img_pos_index = np.where(self.xtra['Id'].values==cur_label)[0]
        img_pos_index = img_pos_index[np.random.randint(0,len(img_pos_index), size=None, dtype=np.int)]
        img_pos = super().get(img_pos_index)
        
        img_neg_index = np.where(self.xtra['Id'].values!=cur_label)[0]
        img_neg_index = img_neg_index[np.random.randint(low=0, high=len(img_neg_index), size=None, dtype=np.int)]
        img_neg = super().get(img_neg_index)
        
        return TripleImageItem(img, img_pos, img_neg, self.mean, self.std)
        
    def reconstruct(self, t): 
        t[0] = t[0]*self.std+self.mean
        t[1] = t[1]*self.std+self.mean
        t[2] = t[2]*self.std+self.mean
        return TripleImageItem(t[0], t[1], t[2], self.mean, self.std)
        
    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(9,10), **kwargs):
        rows = int(math.sqrt(len(xs)))
        fig, axs = plt.subplots(rows,rows,figsize=figsize)
        fig.suptitle('TripleItems', weight='bold', size=14)
        for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):
            xs[i].to_one().show(ax=ax, y=ys[i], **kwargs)
        plt.tight_layout()
        
    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr='.', suffix:str='', **kwargs):
        res = super().from_df(df, path=path, cols=cols, folder=folder, suffix=suffix, **kwargs)
        return res

  1. 生成数据集实例
from fastai.data_block import CategoryList

IMAGE_SIZE = 224
# 构造项表,此时trn_items的类型为TripleImageList
trn_items = TripleImageItemList.from_df(df=trn_imgs, path=path, folder='train', cols=0)

# 对项表随机划分为训练集、验证集,此时trn_val_items的类型为ItemLists
trn_val_items = trn_items.random_split_by_pct(valid_pct=0.1)

# 对ItemLists添加标记得到有标记的数据集,此时tran_val_ds的类型为LabelLists
trn_val_ds = trn_val_items.label_from_df(cols='Id', label_cls=CategoryList)

# 利用get_transforms产生数据增强变换序列,返回值为两个序列,第一个序列为数据增强的序列,用于训练集;
# 		  第二个序列只是常规的预处理预处理变换序列,用于验证集和测试集;
whl_tfms = get_transforms(do_flip=False, max_zoom=1, max_warp=0, max_rotate=2)
data = trn_val_ds.transform(whl_tfms, size=IMAGE_SIZE, padding_mode= 'border', resize_method=ResizeMethod.SQUISH)

# 将数据集组合成databunch
data = data.databunch(bs=32, num_workers=12, path=Path('.'))
  1. 结果展示
data.show_batch(rows=3, figsize=(16,8), ds_type=DatasetType.Train)

one batch

data.show_batch(rows=3, figsize=(16,8), ds_type=DatasetType.Valid)

valid dataset

参考资料:https://docs.fast.ai/tutorial.itemlist.html

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值