【Few Shot】batch-based 和 episodic-based 两种训练导入输入的区别

前言

现在做的情感计算方向数据量还是比较小的,所以想着后面可能会做 few shot 相关的内容,在读论文的时候注意到了这两种训练模式,所以从代码的角度来记录一下二者的不同。

Batch-based

这就是和普通的训练一样的数据导入,从数据集中选择一个 batch size 大小的子集,扔到模型里面学习训练。首先是定义一个Dataset,构造函数的关键字参数包含了数据集相关的内容信息,可以自己指定。

class Dataset:
	def __init__(self, **kwargs):
		self.transform = build_transform()
		self.im_path = kwargs['im_path']
		self.labels = kwargs['labels']
		...
	
	def __getitem__(self, idx):
		im = Image.open(self.im_path[idx]).convert("RGB")
		im = self.transform(im)
		return im, self.labels[idx]
	
	def __len__(self):
		return len(self.labels)

然后利用上面建立的 Dataset 类创建一个 dataloader。

class Dataloader:
	def __init__(slef, *args, **kwargs):
		...
	
	def get_dataloader(self, **kwargs):
		transform = build_transform()
		dataset = Dataset(kwargs)
		dataloader = torch.utils.data.DataLoader(dataset, kwargs)
		return dataloader

确实和普通的是一样的。

Episodic-based

这个就主要用在小样本学习里面了,
首先定义 Dataset 类,这个和之前的也是一样,直接定义就行了。但是开源代码中在 Dataset 的 transform 中定义了采样支撑集和 query 的函数。

def extract_episode(n_support, n_query, d):
    # data: N x C x H x W
    n_examples = d["data"].size(0)

    if n_query == -1:
        n_query = n_examples - n_support

    example_inds = torch.randperm(n_examples)[: (n_support + n_query)]
    support_inds = example_inds[:n_support]
    query_inds = example_inds[n_support:]

    xs = d["data"][support_inds]
    xq = d["data"][query_inds]

    return {"class": d["class"], "xs": xs, "xq": xq}

然后在实例化 Dataset 的时候:

transform = [
	...
	partial(extract_episode, n_support, n_query)
]
dataset = Dataset(transform=transform, ...)

在参考的开源代码中,他们是定义了一个采样器 sampler,利用这个自定义的采样器在 DataLoader 中进行采样。

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[: self.n_way]


dataloader = torch.utils.data.DataLoader(
		ds, batch_sampler=sampler, num_workers=0)

在采样器的 __iter__ 函数中,就是从总共的类别数中选择 n_way 个类别。这样就采样到了一份训练数据。

参考代码

prototypical-networks
CloserLookFewShot

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

A91A981E

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值