数据流
one task: c way k shot [支持图像(每类k个样本)和查询图像(c个图像)];
batch: 基于给定支持集,采样batch组查询图像,构成一个episode ,同一个episode中的不同查询图像共享支持样本;
迭代次数:episode的个数;
eg:
(1)5 way 1 shot,b=10
支持集(5x1,c,w,h,)
查询集(5x10,c,w,h)
(2)20 way 5 shot,b=10
支持集(20x5,c,w,h,)
查询集(20x10,c,w,h)
模型
特征提取模块
常规小样本图像分类 4个卷积层 提取网络
关系度量模块
两个卷积层和两个全连接
损失
MSE
技巧
网络权重自定义的初始化方式
# 定义初始化函数
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1)
m.bias.data.zero_()
elif classname.find('Linear') != -1:
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data = torch.ones(m.bias.data.size())
# 调用初始化方法
feature_encoder = CNNEncoder()
feature_encoder.apply(weights_init)
张量的变形操作
unsqueeze()
squeeze()
repeat()
transpose()
view()
cat()