数据集
Omniglot
- 包含50个字母表的1623个手写字符,每个字符包含20个样本
- 先调整尺寸到28x28,之后通过多次旋转90度的方式增加字符的种类,一共6492类
- 划分
- 训练集:82240项 4112类
- 验证集:13760项 688类
- 测试集:33840项 1692类
Mini-ImageNet
- 从ImageNet中随机选取100个类,每类包含600个样本
- 将尺寸缩放到84x84
- 包含
- 训练集:64类
- 验证集:16类
- 测试集:20类
数据准备
每个iteration包含多个batch,也就是多个eposide;每个eposide包含随机的classes_per_it个类别,每个类别包含随机选择的sample_per_class个样本组成support set,query set由这些类中的一个随机类的一个随机样本组成。由于这些样本是作为一个序列输入到模型中的,所以最后一个样本即为query set,也就是要预测标签的样本。输入时,将一个batch中的所有eposide的样本拼接起来一起输入。
模型
将图像输入到时序卷积网络前,先要对图像做特征提取
特征提取
- Omniglot:使用和PrototpicalNet相同的结构
- Mini-ImageNet:在PrototpicalNet中,使用的是和Omniglot相同的结构,通道数减少到32,但是这样浅层的特征提取网络没有充分的利用SNAIL的容量,所以使用了ResNet进行特征提取
[ 84 , 84 , 3 ] → [ 42 , 42 , 64 ] → [ 21 , 21 , 96 ] → [ 10 , 10 , 128 ] → [ 5 , 5 , 256 ] → [ 5 , 5 , 2048 ] → [ 1 , 1 , 2048 ] → [ 1 , 1 , 384 ] [84,84,3]\rightarrow[42,42,64]\rightarrow[21,21,96]\rightarrow[10,10,128]\rightarrow[5,5,256]\rightarrow[5,5,2048]\rightarrow[1,1,2048]\rightarrow[1,1,384] [84,84,3]→[42,42,64]→[21,21,96]→[10,<