Few-Shot Classification Basics: Loading the Dataset

Overview

Few-shot classification is a subfield of machine learning and artificial intelligence that addresses the problem of learning to classify new samples when training data is extremely limited. In traditional supervised learning, a model is trained on a dataset with a large number of labeled samples and abundant examples for every class. In practice, however, collecting that much labeled data can be difficult or expensive.
Introductory material on few-shot learning is scarce online, and I myself was long confused by concepts such as the episode (a small N-way K-shot classification task sampled from the dataset, consisting of a labeled support set and a query set to classify). Only after working through papers and code did I understand how few-shot models are actually trained. This post summarizes the dataset-loading part of that pipeline, and I hope it gives readers a useful starting point.

Steps

1. Reorganize the file structure

- data_name
--- images
----- folder_name1
------- img1.png
------- img2.png
----- folder_name2
--- meta
----- classes.txt  
----- fsl_train.txt
----- fsl_test.txt
----- fsl_train_class.txt
----- fsl_test_class.txt

Here folder_name1 and folder_name2 are the image folder names, usually the class names themselves; in some datasets they may instead be numeric indices (e.g., the numbers 1-100).

2. Locate the label file classes.txt

classes.txt lists every class in the dataset. If your dataset does not provide one, you need to build it yourself (see the sketch after the example below). Its content looks roughly like this:

class_name1
class_name2
class_name3
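
If classes.txt is missing and the image folders are named after the classes, a minimal sketch like the following can generate it (the path is a placeholder; adjust data_root to your dataset):

import os

data_root = "/path/to/data_name"  # placeholder path
images_dir = os.path.join(data_root, "images")
# One class name per line, sorted for a stable order
class_names = sorted(os.listdir(images_dir))
with open(os.path.join(data_root, "meta", "classes.txt"), "w") as f:
    for name in class_names:
        f.write(f"{name}\n")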

3. Generate the split files with code

The generated files are fsl_train.txt, fsl_test.txt, fsl_train_class.txt, and fsl_test_class.txt.
The code currently supports the following cases:

  • (1) folder_name is the class name
  • (2) folder_name is the 1-based index corresponding to the class name
  • (3) the image names under folder_name are all numbers, with no other characters
  • (4) the image names under folder_name may contain any characters

The generated files look roughly as follows (the original screenshots are omitted). fsl_train.txt lists one image per line in the form class_name/img_name, e.g.:

class_name1/1.jpg
class_name1/2.jpg

fsl_train_class.txt lists one training class per line (a class name, or an index when the folders are numbered), in the same format as classes.txt.

The code is:

import os
import random

# list_from_file is the author's helper from util.tools; it reads one entry per line
from util.tools import list_from_file

def make_file(img_root_path, names, path, is_num):
    """
    :param img_root_path: root folder of the images
    :param names: folder names (one per class) to include
    :param path: output file path
    :param is_num: whether the image file names are numeric
    """
    with open(path, "w") as f:
        for name in names:
            img_dir = os.path.join(img_root_path, str(name))
            img_names = os.listdir(img_dir)
            if is_num:
                # Numeric file names: sort by their integer value
                sort_img_names = sorted(img_names, key=lambda s: int(s.split('.')[0]))
            else:
                sort_img_names = sorted(img_names)
            for img_name in sort_img_names:
                # Strip the root prefix so the entry becomes class_name/img_name
                img_path = os.path.join(img_dir, img_name).replace(img_root_path + "/", "")
                f.write(f"{img_path}\n")
            
def generate_split_dataset(data_root, train_num, is_imgs_id, is_img_name_num):
    """
    :param data_root: dataset root directory
    :param train_num: number of classes used for training
    :param is_imgs_id: whether the image folder names are numeric indices
    :param is_img_name_num: whether the image file names are numeric
    :return: None
    """
    class_path = os.path.join(data_root, "meta", "classes.txt")
    class_list = list_from_file(class_path)
    if is_imgs_id:
        # Folder indices start from 1; adjust this if your dataset differs
        id2class = {i + 1: _class for i, _class in enumerate(class_list)}
        all_ids = list(range(1, len(class_list) + 1))
    else:
        id2class = {i: _class for i, _class in enumerate(class_list)}
        all_ids = list(range(len(class_list)))
    # Pick train_num classes for training; the remaining classes are used for testing
    train_class_ids = random.sample(all_ids, train_num)
    test_class_ids = [i for i in all_ids if i not in train_class_ids]
    # Resolve the folder names under images/
    if is_imgs_id:
        train_class_name = train_class_ids
        test_class_name = test_class_ids
    else:
        train_class_name = [id2class[id] for id in train_class_ids]
        test_class_name = [id2class[id] for id in test_class_ids]
    # Sort for a deterministic order
    train_class_name = sorted(train_class_name)
    test_class_name = sorted(test_class_name)
    train_class_save_path = os.path.join(data_root, "meta", "fsl_train_class.txt")
    test_class_save_path = os.path.join(data_root, "meta" , "fsl_test_class.txt")
    with open(train_class_save_path, "w") as f:
        for cls_name in train_class_name:
            f.write(f"{cls_name}\n")

    with open(test_class_save_path, "w") as f:
        for cls_name in test_class_name:
            f.write(f"{cls_name}\n")

    # Save the image lists, one class_name/img_name entry per line
    img_root_path = os.path.join(data_root, "images")
    train_imgs_name_path = os.path.join(data_root, "meta", "fsl_train.txt")
    test_imgs_name_path = os.path.join(data_root, "meta", "fsl_test.txt")
    make_file(img_root_path, train_class_name, train_imgs_name_path, is_img_name_num)
    make_file(img_root_path, test_class_name, test_imgs_name_path, is_img_name_num)
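
As a usage sketch (the path and the class count are placeholders), splitting a dataset whose folders are class names and whose image names are numeric would look like:

generate_split_dataset(
    data_root="/path/to/data_name",  # placeholder path
    train_num=120,                   # hypothetical: 120 classes go to training
    is_imgs_id=False,                # folders are named after classes
    is_img_name_num=True             # image names are pure numbers
)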

4. Build the BaseFewShotDataset class

BaseFewShotDataset loads the class-name file and the annotation file, and provides the shared plumbing (per-class index lookup, image loading, per-class shot sampling) used by the concrete dataset classes below. The code is:

import copy
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Mapping, Optional, Sequence, Union
from torch.utils.data import Dataset
import numpy as np
import os.path as osp
from PIL import Image
import torch

from util import tools

class BaseFewShotDataset(Dataset, metaclass=ABCMeta):
    def __init__(self,
                 pipeline,
                 data_prefix: str,
                 classes: Optional[Union[str, List[str]]] = None,
                 ann_file: Optional[str] = None) -> None:
        super().__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.pipeline = pipeline
        self.CLASSES = self.get_classes(classes)
        self.data_infos = self.load_annotations()
        self.data_infos_class_dict = {i: [] for i in range(len(self.CLASSES))}
        for idx, data_info in enumerate(self.data_infos):
            self.data_infos_class_dict[data_info['gt_label'].item()].append(
                idx)

    def load_image_from_file(self, info_dict):
        img_prefix = info_dict['img_prefix']
        img_name = info_dict['img_info']['filename']
        img_file = osp.join(img_prefix, img_name)
        img_data = Image.open(img_file).convert('RGB')
        return img_data

    @abstractmethod
    def load_annotations(self):
        pass

    @property
    def class_to_idx(self) -> Mapping:
        return {_class: i for i, _class in enumerate(self.CLASSES)}

    def prepare_data(self, idx: int) -> Dict:
        results = copy.deepcopy(self.data_infos[idx])
        imgs_data = self.load_image_from_file(results)
        data = {
            "img" : self.pipeline(imgs_data),
            "gt_label" : torch.tensor(self.data_infos[idx]['gt_label'])
        }
        return data

    def sample_shots_by_class_id(self, class_id: int,
                                 num_shots: int) -> List[int]:
        all_shot_ids = self.data_infos_class_dict[class_id]
        return np.random.choice(
            all_shot_ids, num_shots, replace=False).tolist()

    def __len__(self) -> int:
        return len(self.data_infos)

    def __getitem__(self, idx: int) -> Dict:
        return self.prepare_data(idx)

    @classmethod
    def get_classes(
            cls,
            classes: Optional[Union[Sequence[str], str]] = None) -> Sequence[str]:
        if isinstance(classes, str):
            class_names = tools.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names
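
The helper util.tools.list_from_file is used throughout but never shown in the post. Presumably it just reads one entry per line; a minimal sketch under that assumption:

def list_from_file(path):
    # Read one stripped, non-empty entry per line
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]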

5. Build the generic few-shot dataset class UniversalFewShotDataset

This class reads the entries out of the annotation file and loads the corresponding data.
The code is as follows:

from datasets.base import BaseFewShotDataset
from typing_extensions import Literal
from typing import List, Optional, Sequence, Union
from util import tools
import os.path as osp
import numpy as np
import torchvision.transforms as transforms

class UniversalFewShotDataset(BaseFewShotDataset):
    def __init__(self,
                 img_size,
                 subset: Literal['train', 'test', 'val'] = 'train',
                 *args,
                 **kwargs):
        if isinstance(subset, str):
            subset = [subset]
        for subset_ in subset:
            assert subset_ in ['train', 'test', 'val']
        self.subset = subset
        # Normalization parameters (ImageNet statistics)
        norm_params = {'mean': [0.485, 0.456, 0.406],
                       'std': [0.229, 0.224, 0.225]}
        # Build the preprocessing pipeline
        if subset[0] == 'train':
            pipeline = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
            ])
        else:  # 'test' and 'val' share the same deterministic pipeline
            pipeline = transforms.Compose([
                transforms.Resize(size=int(img_size * 1.15)),
                transforms.CenterCrop(size=img_size),
                transforms.ToTensor(),
                transforms.Normalize(**norm_params)
            ])
        super().__init__(pipeline=pipeline, *args, **kwargs)

    def get_classes(
            self,
            classes: Optional[Union[Sequence[str], str]] = None) -> Sequence[str]:
        class_names = tools.list_from_file(classes)
        return class_names
    
    # Parse the annotation file (one class_name/img_name entry per line)
    def load_annotations(self) -> List:
        data_infos = []
        ann_file = self.ann_file
        with open(ann_file) as f:
            for i, line in enumerate(f):
                class_name, filename = line.strip().split('/')
                gt_label = self.class_to_idx[class_name]
                info = {
                    'img_prefix':
                    osp.join(self.data_prefix, 'images', class_name),
                    'img_info': {
                        'filename': filename
                    },
                    'gt_label':
                    np.array(gt_label, dtype=np.int64)
                }
                data_infos.append(info)
        return data_infos
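
A usage sketch (the paths are placeholders) that instantiates the dataset and inspects one sample:

dataset = UniversalFewShotDataset(
    img_size=224,
    subset='train',
    data_prefix='/path/to/data_name',                        # placeholder
    classes='/path/to/data_name/meta/fsl_train_class.txt',   # placeholder
    ann_file='/path/to/data_name/meta/fsl_train.txt')        # placeholder
sample = dataset[0]
# sample['img'] is a [3, 224, 224] float tensor, sample['gt_label'] a scalar tensor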

6. Build the meta-learning dataset class EpisodicDataset

EpisodicDataset wraps a base dataset and pre-samples num_episodes episodes; each episode picks num_ways classes and, for each class, num_shots support images plus num_queries query images. The code is as follows:

import numpy as np
from torch.utils.data import Dataset

from util import tools

class EpisodicDataset:
    def __init__(self,
                 dataset: Dataset,
                 num_episodes: int,
                 num_ways: int,
                 num_shots: int,
                 num_queries: int,
                 episodes_seed: int):
        self.dataset = dataset
        self.num_ways = num_ways
        self.num_shots = num_shots
        self.num_queries = num_queries
        self.num_episodes = num_episodes
        self._len = len(self.dataset)
        self.CLASSES = dataset.CLASSES
        self.episodes_seed = episodes_seed
        self.episode_idxes, self.episode_class_ids = \
            self.generate_episodic_idxes()

    def generate_episodic_idxes(self):
        """Generate the sample indices for each episode."""
        episode_idxes, episode_class_ids = [], []
        class_ids = [i for i in range(len(self.CLASSES))]
        # The seeded context is optional; it fixes the sampled episodes for reproducibility
        with tools.local_numpy_seed(self.episodes_seed):
            for _ in range(self.num_episodes):
                np.random.shuffle(class_ids)
                # sample classes
                sampled_cls = class_ids[:self.num_ways]
                episode_class_ids.append(sampled_cls)
                episodic_support_idx = []
                episodic_query_idx = []
                # sample instances of each class
                for i in range(self.num_ways):
                    shots = self.dataset.sample_shots_by_class_id(
                        sampled_cls[i], self.num_shots + self.num_queries)
                    episodic_support_idx += shots[:self.num_shots]
                    episodic_query_idx += shots[self.num_shots:]
                episode_idxes.append({
                    'support': episodic_support_idx,
                    'query': episodic_query_idx
                })
        return episode_idxes, episode_class_ids

    def __getitem__(self, idx: int):
        support_data = [self.dataset[i] for i in self.episode_idxes[idx]['support']]
        query_data = [self.dataset[i] for i in self.episode_idxes[idx]['query']]
        return {
            'support_data': support_data,
            'query_data': query_data
        }

    def __len__(self):
        return self.num_episodes

    def evaluate(self, *args, **kwargs):
        return self.dataset.evaluate(*args, **kwargs)

    def get_episode_class_ids(self, idx: int):
        return self.episode_class_ids[idx]
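
util.tools.local_numpy_seed is also not shown in the post; it is presumably a context manager that temporarily seeds NumPy and restores the previous random state on exit. A minimal sketch under that assumption:

import contextlib
import numpy as np

@contextlib.contextmanager
def local_numpy_seed(seed):
    # Temporarily seed NumPy's global RNG, then restore the previous state
    state = np.random.get_state()
    try:
        np.random.seed(seed)
        yield
    finally:
        np.random.set_state(state)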

7. Write your own configuration file, e.g., in JSON

The configuration file does not have to be JSON; any format works. Here JSON is used as the example:

{
    "train":{
        "num_episodes":2000,
        "num_ways":10,
        "num_shots":5,
        "num_queries":5,
        "episodes_seed":1001,
        "per_gpu_batch_size":1,
        "per_gpu_workers": 8,
        "epoches": 160,
        "dataset":{
            "name": "vireo_172",
            "img_size": 224,
            "data_prefix":"/home/gaoxingyu/dataset/vireo-172/",
            "classes":"/home/gaoxingyu/dataset/vireo-172/meta/fsl_train_class.txt",
            "ann": "/home/gaoxingyu/dataset/vireo-172/meta/fsl_train.txt"
        }
    }
}
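
With these settings, each of the 2000 training episodes is a 10-way 5-shot task: 10 classes are sampled, and each contributes 5 support and 5 query images, i.e., 50 support and 50 query images per episode. per_gpu_batch_size is 1 because a single item of EpisodicDataset is already a full episode.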

8. Write the main program and test it

The code is as follows (a snippet from the training script; json, torch, DataLoader, partial, logger, local_rank, collate, and worker_init_fn are assumed to be imported or defined elsewhere):

with open("config.json", 'r', encoding='utf-8') as f:
     f = f.read()
     configs = json.loads(f)
     logger.info(f"Experiment Setting:{configs}")
# Build the datasets
## train dataset
train_food_dataset = UniversalFewShotDataset(
    data_prefix=configs['train']['dataset']['data_prefix'],
    subset="train",
    classes=configs['train']['dataset']['classes'],
    img_size=configs['train']['dataset']['img_size'],
    ann_file=configs['train']['dataset']['ann'])
train_dataset = EpisodicDataset(dataset=train_food_dataset,
                                num_episodes=configs['train']['num_episodes'],
                                num_ways=configs['train']['num_ways'],
                                num_shots=configs['train']['num_shots'],
                                num_queries=configs['train']['num_queries'],
                                episodes_seed=configs['train']['episodes_seed'])
## train dataloader (assumes torch.distributed has been initialized and local_rank is set)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, rank=local_rank, shuffle=True)
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=configs['train']['per_gpu_batch_size'],
    sampler=train_sampler,
    num_workers=configs['train']['per_gpu_workers'],
    # collate and worker_init_fn are defined elsewhere in the training script;
    # the collate here follows the mmcv-style collate(batch, samples_per_gpu)
    collate_fn=partial(collate, samples_per_gpu=1),
    worker_init_fn=worker_init_fn,
    drop_last=True
)
for data in train_data_loader:
    print(data)
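
For a quick single-GPU sanity check without distributed training, the episodic dataset can also be iterated directly (a sketch; every item is already a complete episode):

for episode_idx in range(len(train_dataset)):
    episode = train_dataset[episode_idx]
    # num_ways * num_shots support samples and num_ways * num_queries query samples
    support, query = episode['support_data'], episode['query_data']
    print(len(support), len(query))  # e.g., 50 50 for the config above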