回归原型网络代码episode数据加载

一般 P y T o r c h PyTorch PyTorch加载数据的固定格式是:
dataset = MyDataset() : 构建 D a t a s e t Dataset Dataset对象
dataLoader = DataLoader(dataset) #通过 D a t a L o a d e r DataLoader DataLoader来构造迭代对象.
num_epoches = 100
for epoch in range(num_epoches): #逐步迭代数据
for img,label in dataLoader:
#训练代码
但是小样本有episode这个概念,所以需要额外用一个 s a m p l e r sampler sampler,写篇文章记录下原型网络是怎么加载数据哒.

episode有分 s u p p o r t support support集和 q u e r y query query集, 如果 s u p p o r t support support有N个类,每个类有 N N N个样本,我们就叫做 N w a y − K s h o t Nway-Kshot NwayKshot,另外,在 q u e r y query query集每个样类中有 Q Q Q个样本,注意着 Q Q Q个样本和 S u p p o r t Support Support集中 K K K个样本没有重复样本.

DataLoader是怎么获取数据的?

在这里插入图片描述
当我们使用下面代码获取一个batch数据时候, S a m p l e r Sampler Sampler会先产生64个下标,然后DataLoader会根据这64个下标获取数据,最后封装成一个 t e n s o r tensor tensor,返回给small_data
from torch.utils.data import DataLoader:
train_dataloader = DataLoader(training_data,batch_size = 64,shuffle = True)
from small_data in train_dataloader:
print(data.shape)

原型代码中的dataloader simpler和dataset分别对应于:

dataset:

omniglot.py

ds = TransformDataset(ListDataset(class_names), transforms)

Sampler

base.py

EpisodicBatchSampler

torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)

加载数据产生batch_size

Nway -Kshot-Qquery的episode

dataset

在原型代码中使用omnigolot数据集训练,在图片外面有两个文件夹,第一层是1种类,第二层是2编号,注意源来的 o m i g l o t omiglot omiglot数据集有区分:
images_background.zip和images_background.zip
代码作者分别解压之后,合并到了omniglot/data文件夹下,详细可以看

官方代码

download_omniglot.sh文件。
在这里插入图片描述
在官方代码中有对omiglot进行分割的txt文件,如下图所示:
在这里插入图片描述
可以在想要的训练集,验证集和测试集时分别读取不同的txt文件完成,以train为例, t r a i n . t x t train.txt train.txt的第一行为:

Angelic/character01/rot000
分别代表了1种类/编号/旋转角度.

ll5行,首先读取train.txt文件的所有行,将其存放至容器class_names = []中.
然后在119行

ds = TransformDataset(ListDataset(class_names), transforms)

查看官方文档和教程产生List形式数据

可以发现就是在获得ListDataset(class_names)中的任意一项数据时,
都会对其进行transform变换,transforms变换
内容如下:
trainsformers = [partial(convert_dict,‘class’),
load_class_images,
partial(extract_episode, n_support, n_query)]
其中convert_dict,load_class_images,extract_episode.均为代码作者自定义函数,而 p a r t i a l partial partial关键字就是在调用函数的同时,有几个参数固定为所给定的值,以partial(convert_dict,‘class’),就等同于将convert_dict函数从

def convert_dict(k, v):
    return { k: v }

转变为:

def convert_dict(v):
	# 此时已经不用传k的值了,因为 partial(convert_dict, 'class') 已经给k 赋值了'class'
    return { k: v }

如果有functional(a,b,c,d,e,f,g),则:

  • partial(convert_dict, ‘value1’) 就已经给参数a传递了 v a l u e 1 value1 value1
  • partial(convert_dict, ‘value1’, ‘value2’) 就已经给参数a传了value1、参数b传了value2
  • partial(convert_dict, ‘value1’, ‘value2’ ,‘value3’) 就已经给参数a传了value1、参数b传了value2、参数c传了value3
    等等等>
    以Avesta/character12/rot890为例,我们来看transforme最终能取得什么效果,在原型代码中设置 e p i s o d e episode episode,为 5 w a y − 5 s h o t − 15 q u e r y 5way-5shot-15query 5way5shot15query
    为了区分两个5, 我们在这以 10 w a y − 5 s h o t − 15 q u e r y 10way-5shot-15query 10way5shot15query
    进行举例子.

transforms = [partial(convert_dict, ‘class’),
load_class_images,
partial(extract_episode, 5, 15)]
首先把Avesta/character/rot090传递给参数partial(convert_dict, ‘class’)

def convert_dict(k, v):
    return { k: v }

返回结果为:{ ‘class’ :“Avesta/character12/rot090” }
然后把{ ‘class’ :“Avesta/character12/rot090” }传进load_class_images

def load_class_images(d):  # { 'class' :"Avesta/character12/rot090" }
    if d['class'] not in OMNIGLOT_CACHE:
        alphabet, character, rot = d['class'].split('/')  
        # 值分别为 Avesta  character12  rot090
        
        image_dir = os.path.join(OMNIGLOT_DATA_DIR, 'data', alphabet, character)
        # OMNIGLOT_DATA_DIR是omniglot的根路径,此句就是为了拼接出Avesta/character12的路径

        class_images = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        # 获得路径下以png结尾的文件路径,然后排序,这是一个列表

        if len(class_images) == 0:
            raise Exception("No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(d['class'], image_dir))

        image_ds = TransformDataset(ListDataset(class_images), # 这个同上文讲过的,不再赘述
                                    compose([partial(convert_dict, 'file_name'),
                                             partial(load_image_path, 'file_name', 'data'),
                                             partial(rotate_image, 'data', float(rot[3:])),
                                             partial(scale_image, 'data', 28, 28),
                                             partial(convert_tensor, 'data')]))

        loader = torch.utils.data.DataLoader(image_ds, batch_size=len(image_ds), shuffle=False)
        # 将全部数据封装成一个batch,作为一个episode中的一个类

        for sample in loader:
            OMNIGLOT_CACHE[d['class']] = sample['data']
            break # only need one sample because batch size equal to dataset length

    return { 'class': d['class'], 'data': OMNIGLOT_CACHE[d['class']] }

返回结果为:

```python
{ 'class' :"Avesta/character12/rot090" ,
'data':  size为(20,1,28,28)的一个tensor
}

注意这里的(20,1,28,28)是固定的,因为一个文件夹下面只有20张 28 ∗ 28 28*28 2828的黑白图片
然后在把上面结果传递给extract_episode:注意 10 w a y − 5 s h o t − 15 q u e r y 10way-5shot-15query 10way5shot15query进行举例
所以每个类有5个作为 s u p p o r t support support 15个作为 q u e r y query query

def extract_episode(n_support, n_query, d):
    # data: N x C x H x W
    n_examples = d['data'].size(0)  # 20

    if n_query == -1:
        n_query = n_examples - n_support  

    example_inds = torch.randperm(n_examples)[:(n_support+n_query)]
    # 从20个样本中 选取5+15个
    
    support_inds = example_inds[:n_support]
    # 从中选5个作为support
    
    query_inds = example_inds[n_support:]
    # 剩下的20-5个作为query

	# 根据下标加载数据
    xs = d['data'][support_inds]
    xq = d['data'][query_inds]

    return {
        'class': d['class'],
        'xs': xs,
        'xq': xq
    }

返回结果为

{
	'class': "Avesta/character12/rot090" ,
	'xs': size为(5,1,28,28)的一个tensor ,
	'xq': size为(15,1,28,28)的一个tensor
}

Simpler

根据 E p i s o d i c B a t c h S a m p l e r 的定义和调用 EpisodicBatchSampler的定义和调用 EpisodicBatchSampler的定义和调用,可以根据需求生成类的下标:

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]

sampler = EpisodicBatchSampler(len(ds), n_way, n_episodes)
此时的len(ds)等于train.txt行数,

解释

此时的 l e n ( d s ) len(ds) len(ds)等于train.txt行数,也就是4112行,

  • n_way等于10,.
  • n_episodes等于100
  • 就是一个epoch里面有100个episode.
  • 每个episode有10个类,从采样定义,我们可以发现,认为字符和转换某个角度(90、180、270)后的字符,是不i一样的,可以看作是一种数据增强吧,另外 S a m p l e r Sampler Sampler的代码显示其返回的list有10个下标,

dataloader

然后 d a t a l o a d e r dataloader dataloader根据它提供的10个下标,去 d a t a s e t dataset dataset找对应下标的数据.

{
	'class': "Avesta/character12/rot090" ,
	'xs': size为(5,1,28,28)的一个tensor ,
	'xq': size为(15,1,28,28)的一个tensor
}

然后dataloader将10个数据合并成一个episode.
在lin35行使用了dataloader:

https://github.com/jakesnell/prototypical-networks/blob/c9bb4d258267c11cb6e23f0a19242d24ca98ad8a/protonets/engine.py#L35

其中state[‘loader’]定义在lin39行,lin39
tqdm相当于有进度条的 f o r for for循环.
在这里插入图片描述
因此,功能上可以看作:

for sample in tqdm(state['loader'], desc="Epoch {:d} train".format(state['epoch'] + 1)):
	...
# 等价于
for sample in train_loader:
	...


调试分析sampler可以看作和推断一致.
在这里插入图片描述

总结

会自己将代码研究透彻,构造各种数据框架,会自己研究代码,将其全部都搞定都行啦的理由与打算.

  • 慢慢的会自己将代码都给其弄明白,全部都将其搞透彻,研究彻底都行啦的里由与打算.
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

big_matster

您的鼓励,是给予我最大的动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值