一般pytorch加载数据的固定格式是:
dataset = MyDataset() # 第一步:构造Dataset对象
dataloader = DataLoader(dataset)# 第二步:通过DataLoader来构造迭代对象
num_epoches = 100
for epoch in range(num_epoches):# 第三步:逐步迭代数据
for img, label in dataloader:
# 训练代码
但是小样本有episode
这个概念,所以需要额外用到一个sampler
。写篇文章记录下原型网络是怎么进行数据加载。
episode
有分supprt
集和query
集,如果supprt
集中有
N
N
N个类,每个类有
K
K
K个样本,我们就交做
N
w
a
y
−
K
s
h
o
t
N way -K shot
Nway−Kshot。另外在query
集每个类中有
Q
Q
Q个样本,注意这
Q
Q
Q个样本和supprt
集中
K
K
K个样本没重复的样本。
文章目录
DataLoader是怎么获得数据的
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系简单来说,就是如下图所示
当我们使用下面代码获取一个batch
的数据的时候,sampler
会先产生64个下标,然后Dataset会根据这64个下标获得数据,最后封装成一个tensor,返回给small_data
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
for small_data in train_dataloader:
print(data.shape)
原型代码中的dataloader、simpler和dataset分别对应于
dataset
:omniglot.py的
ds = TransformDataset(ListDataset(class_names), transforms)
simpler
: base.py的
EpisodicBatchSampler
dataloader
:omniglot.py的
torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)
N
w
a
y
−
K
s
h
o
t
−
Q
q
u
e
r
y
N way -K shot - Q query
Nway−Kshot−Qquery的episode
dataset
在原型代码中用omniglot
数据集训练,在图片外面有两个文件夹,第一层是①种类,第二层是②编号。注意原来的omniglot
数据集有区分images_background.zip
和images_background.zip
,代码作者分别解压之后,合并到了omniglot/data
文件夹下,详细可以看官方代码的download_omniglot.sh
文件。
在官方代码中有对omniglot
进行分割的txt文件,如下图所示。
可以在想要训练集、验证集和测试集时分别读取不同的txt文件完成。以train为例,train.txt
的第一行为:
Angelic/character01/rot000
分别代表了①种类/②编号/旋转的角度。
line115首先读取train.txt
的所有行,将它保存至class_names = [ ]
中.
然后在119行
ds = TransformDataset(ListDataset(class_names), transforms)
先用torchnet.dataset.ListDataset
将class_names
转换成ListDataset
,
查看官方文档和教程_产生List形式数据,可以发现就是在获得ListDataset(class_names)
中的任意一项数据时,都会对它、数据进行transforms变换,内容如下:
transforms = [partial(convert_dict, 'class'),
load_class_images,
partial(extract_episode, n_support, n_query)]
其中convert_dict、load_class_images、extract_episode均为代码作者自定义的函数。而partial
关键字就是在调用函数的同时,有几个参数固定为所给定的值。以partial(convert_dict, 'class')
为例,就等同于将convert_dict
函数从
def convert_dict(k, v):
return { k: v }
转变为
def convert_dict(v):
# 此时已经不用传k的值了,因为 partial(convert_dict, 'class') 已经给k 赋值了'class'
return { k: v }
如果有function(a,b,c,d,e,f,g)
,则
partial(convert_dict, 'value1')
就已经给参数a
传了value1
partial(convert_dict, 'value1', 'value2')
就已经给参数a
传了value1
、参数b
传了value2
partial(convert_dict, 'value1', 'value2' ,'value3')
就已经给参数a
传了value1
、参数b
传了value2
、参数c
传了value3
- 等等等等
以Avesta/character12/rot090
为例,我们来看transforms最终能取得什么效果。在原型代码中设置的episode
为
5
w
a
y
−
5
s
h
o
t
−
15
q
u
e
r
y
5way - 5shot - 15query
5way−5shot−15query,为了区分两个5,我们在这以
10
w
a
y
−
5
s
h
o
t
−
15
q
u
e
r
y
10way - 5shot - 15query
10way−5shot−15query进行举例。
transforms = [partial(convert_dict, 'class'),
load_class_images,
partial(extract_episode, 5, 15)]
- 首先把
Avesta/character12/rot090
传进partial(convert_dict, 'class')
def convert_dict(k, v):
return { k: v }
返回结果为 { 'class' :"Avesta/character12/rot090" }
- 然后把
{ 'class' :"Avesta/character12/rot090" }
传进load_class_images
def load_class_images(d): # { 'class' :"Avesta/character12/rot090" }
if d['class'] not in OMNIGLOT_CACHE:
alphabet, character, rot = d['class'].split('/')
# 值分别为 Avesta character12 rot090
image_dir = os.path.join(OMNIGLOT_DATA_DIR, 'data', alphabet, character)
# OMNIGLOT_DATA_DIR是omniglot的根路径,此句就是为了拼接出Avesta/character12的路径
class_images = sorted(glob.glob(os.path.join(image_dir, '*.png')))
# 获得路径下以png结尾的文件路径,然后排序,这是一个列表
if len(class_images) == 0:
raise Exception("No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(d['class'], image_dir))
image_ds = TransformDataset(ListDataset(class_images), # 这个同上文讲过的,不再赘述
compose([partial(convert_dict, 'file_name'),
partial(load_image_path, 'file_name', 'data'),
partial(rotate_image, 'data', float(rot[3:])),
partial(scale_image, 'data', 28, 28),
partial(convert_tensor, 'data')]))
loader = torch.utils.data.DataLoader(image_ds, batch_size=len(image_ds), shuffle=False)
# 将全部数据封装成一个batch,作为一个episode中的一个类
for sample in loader:
OMNIGLOT_CACHE[d['class']] = sample['data']
break # only need one sample because batch size equal to dataset length
return { 'class': d['class'], 'data': OMNIGLOT_CACHE[d['class']] }
返回结果为
{ 'class' :"Avesta/character12/rot090" ,
'data': size为(20,1,28,28)的一个tensor
}
注意(20,1,28,28)
是固定的,因为一个文件夹下面只有20张28*28的黑白图片
- 然后把上面的返回结果传进
extract_episode
,注意是 10 w a y − 5 s h o t − 15 q u e r y 10way - 5shot - 15query 10way−5shot−15query进行举例,所以每个类有5个作为support
,15个作为query
def extract_episode(n_support, n_query, d):
# data: N x C x H x W
n_examples = d['data'].size(0) # 20
if n_query == -1:
n_query = n_examples - n_support
example_inds = torch.randperm(n_examples)[:(n_support+n_query)]
# 从20个样本中 选取5+15个
support_inds = example_inds[:n_support]
# 从中选5个作为support
query_inds = example_inds[n_support:]
# 剩下的20-5个作为query
# 根据下标加载数据
xs = d['data'][support_inds]
xq = d['data'][query_inds]
return {
'class': d['class'],
'xs': xs,
'xq': xq
}
返回结果为:
{
'class': "Avesta/character12/rot090" ,
'xs': size为(5,1,28,28)的一个tensor ,
'xq': size为(15,1,28,28)的一个tensor
}
总而言之
line38的流程为:
line 107的transform的流程为
simpler
根据EpisodicBatchSampler的定义和调用,可以根据需求生成类的下标
class EpisodicBatchSampler(object):
def __init__(self, n_classes, n_way, n_episodes):
self.n_classes = n_classes
self.n_way = n_way
self.n_episodes = n_episodes
def __len__(self):
return self.n_episodes
def __iter__(self):
for i in range(self.n_episodes):
yield torch.randperm(self.n_classes)[:self.n_way]
sampler = EpisodicBatchSampler(len(ds), n_way, n_episodes)
此时len(ds)
等于train.txt
的行数,也就是4112,n_way
等于10, n_episodes
等于100,就是一个epoch
里有100个episode
,每个episode
有10个类。从EpisodicBatchSampler的定义的定义,我们可以发现,官方代码认为字符和转换某个角度(90、180、270)后的字符,是不一样的。可以看做是一种数组增强吧🤡。
另外sampler
的代码显示它返回的是list,有10个下标。
dataloader
然后dataloader
根据它提供的10个下标,去dataset
找对应下标的数据:
{
'class': "Avesta/character12/rot090" ,
'xs': size为(5,1,28,28)的一个tensor ,
'xq': size为(15,1,28,28)的一个tensor
}
然后dataloader
将10个数据合并成一个episode
在line35使用了dataloader
for sample in tqdm(state['loader'], desc="Epoch {:d} train".format(state['epoch'] + 1)):
其中state['loader']
定义在line39,tqdm
相当于有进度条的for
循环
因此,功能上可以看作
for sample in tqdm(state['loader'], desc="Epoch {:d} train".format(state['epoch'] + 1)):
...
# 等价于
for sample in train_loader:
...
调试分析sample
,可以看到和推断一致