一、引言
在学习小样本学习之元学习实现方法中,遇见MAML(模型不可知元学习)算法,通过元学习入门必备:MAML(背景+论文解读+代码分析)_元学习算法代码-CSDN博客该博客的末尾转到https://zhuanlan.zhihu.com/p/343827171知乎可以看到代码的下载地址,本博客旨在分析学习该代码,记录MAML的实现过程。
其中数据集选用Omniglot数据集。
注意:请区分:
任务数task_num(batchsz)、样本类别数(1623 characters)、选取的样本类别数n_way、批次数(episodes)
二、MAML
Meta-Learning即“学习如何学习”,展示的是一种思想,例如某个学习率的大小可以不再是手动设置的超参数,而是通过模型学习得到。其中MAML指的是学习一组最优的初始化参数,该参数能够应对不同的新任务,在新任务上通过简单的几步训练即可获得很好的效果,因此MAML的核心就是通过某种训练方法,让模型的权重和偏置参数处于一种很容易在新任务上收敛的位置。
2.1 元学习数据集
MAML是元学习的一种,因此数据处理和元学习的数据处理是一致的。
元学习的数据集包括Training_Data训练资料和Testing_Data测试资料,其中Training_Data里面又包含Support_Set支持集和Query_Set查询集,Testing_Data也包含Support_Set支持集和Query_Set查询集。
其中Spt用来获取梯度信息,Qry用来计算损失。
Training_Data是为了让模型收敛在某个好的位置,但并不是在某个任务上(例如目标分类)准确率最高,而是在新的任务上(Testing_Data)能够很快的得到很好的效果。
浅绿色线和深绿色线表示两个task的Loss随着Model Parameter变化的曲线。
左图MAML力求通过Training_Data收敛到到phi所在的位置,使得其在Testing_Data中能够很快的通过spt收敛到theta1和theta2所在位置。其中,n表示第n个任务,
表示第n个任务的参数,其中第n个任务的参数是由第n个任务的Spt得来,损失由Qry计算。区别于
表示的是第n个任务在同一个参数phi。
与Pre-training的区别在于,其力求在每个task上表现综合最优,因此模型容易收敛在由于phi位置,使得l1+l2和最小。
2.2 数据集处理
Omniglot数据集包含了1623组手写characters,每组character包含20张105*105的0单通道图像,其中同一个character的20张图像是20个不同的人手写的同一个字。因为其有大量的类别且每个类别的图像数量较少,因此该数据集适合小样本模型训练。
针对小样本学习,数据集batch格式通常为(n_way, k_shot, k_query),分别表示该batch的类别数、支持集样本数、查询集样本数。
三、代码逐段分析
下载解压后如图所示的文件结构,如果没有数据集,可以后边运行代码自动下载。
3.1 omniglot_train.py(onniglot训练、定义参数的主函数)
代码入口:
if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--epoch', type=int, help='epoch number', default=40000)
argparser.add_argument('--n_way', type=int, help='n way', default=5)
argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=1)
argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=15)
argparser.add_argument('--imgsz', type=int, help='imgsz', default=28)
argparser.add_argument('--imgc', type=int, help='imgc', default=1)
argparser.add_argument('--task_num', type=int, help='meta batch size, namely task num', default=32)
argparser.add_argument('--meta_lr', type=float, help='meta-level outer learning rate', default=1e-3)
argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.4)
argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)
args = argparser.parse_args()
运行该文件代码从此处开始执行。
task_num表示一个batch中包含task_num个task,每个task内包含n_way个类别,每个类别包含k_spt个支持即样本、k_qry个查询集样本。
meta_lr为模型参数最终(实际)更新的(outer)lr,update_lr为模型参数在每个任务训练时的更新(inner)lr。
update_step为模型通过Training_Data训练时,模型更新的次数,也就是inner loop中,每个task喂给模型喂几次。
update_step_test就是对应Testing_Data时相同任务喂该数据的次数。
进入main主函数
config = [
('conv2d', [64, 1, 3, 3, 2, 0]),
('relu', [True]),
('bn', [64]),
('conv2d', [64, 64, 3, 3, 2, 0]),
('relu', [True]),
('bn', [64]),
('conv2d', [64, 64, 3, 3, 2, 0]),
('relu', [True]),
('bn', [64]),
('conv2d', [64, 64, 2, 2, 1, 0]),
('relu', [True]),
('bn', [64]),
('flatten', []),
('linear', [args.n_way, 64])
]
device = torch.device('cuda')
maml = Meta(args, config).to(device)
tmp = filter(lambda x: x.requires_grad, maml.parameters())
num = sum(map(lambda x: np.prod(x.shape), tmp))
print(maml)
print('Total trainable tensors:', num)
config定义了一个网络结构,('name', [out_ch, in_ch, kernelsz, kernelsz, stride, padding])。
其中
maml = Meta(args, config).to(device)
使得模型进入meta.py
3.2 meta.py
maml是Meta()的实例化。
class Meta(nn.Module):
"""
Meta Learner
"""
def __init__(self, args, config):
"""
:param args:
"""
super(Meta, self).__init__()
self.update_lr = args.update_lr
self.meta_lr = args.meta_lr
self.n_way = args.n_way
self.k_spt = args.k_spt
self.k_qry = args.k_qry
self.task_num = args.task_num
self.update_step = args.update_step
self.update_step_test = args.update_step_test
self.net = Learner(config, args.imgc, args.imgsz)
self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
其中
self.net = Learner(config, args.imgc, args.imgsz)
将config再次转到Learner类中,进入Learner()。
3.3 Learner.py(构建网络结构并实现前向传播)
class Learner(nn.Module):
"""
"""
def __init__(self, config, imgc, imgsz):
"""
:param config: network config file, type:list of (string, list)
:param imgc: 1 or 3
:param imgsz: 28 or 84
"""
super(Learner, self).__init__()
self.config = config
# this dict contains all tensors needed to be optimized
self.vars = nn.ParameterList() # vars即为所有的可训练参数w
# running_mean and running_var
self.vars_bn = nn.ParameterList() # vars_bn即为所有的可训练参数bn
for i, (name, param) in enumerate(self.config):
if name is 'conv2d':
# [ch_out, ch_in, kernelsz, kernelsz] # [输出通道数, 输入通道数, 卷积核大小x, 卷积核大小y]
w = nn.Parameter(torch.ones(*param[:4])) # python的解包操作,将param的前四个参数作为一个整体传入
# gain=1 according to cbfin's implementation
torch.nn.init.kaiming_normal_(w) # 用一个kaiming_normal的初始化方式对w进行初始化
self.vars.append(w)
# [ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[0])))
elif name is 'convt2d':
# [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
w = nn.Parameter(torch.ones(*param[:4]))
# gain=1 according to cbfin's implementation
torch.nn.init.kaiming_normal_(w)
self.vars.append(w)
# [ch_in, ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[1])))
elif name is 'linear':
# [ch_out, ch_in]
w = nn.Parameter(torch.ones(*param))
# gain=1 according to cbfinn's implementation
torch.nn.init.kaiming_normal_(w)
self.vars.append(w)
# [ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[0])))
elif name is 'bn':
# [ch_out]
w = nn.Parameter(torch.ones(param[0]))
self.vars.append(w)
# [ch_out]
self.vars.append(nn.Parameter(torch.zeros(param[0])))
# must set requires_grad=False
running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
self.vars_bn.extend([running_mean, running_var])
elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
continue
else:
raise NotImplementedError
在__init__()中,第一次见这种先把模型的所有可训练参数放在一对列表中。好处是自己可以决定每个参数的初始化方法,本代码中使用的是torch.nn.init.kaiming_normal_(w)初始化方式。
self.vars = nn.ParameterList() # vars即为所有的可训练参数w self.vars_bn = nn.ParameterList() # vars_bn即为所有的可训练参数bn
其中
w = nn.Parameter(torch.ones(*param[:4]))
表示定义了out_ch * in_ch * kernelsz * kernelsz个可训练的权重参数。
另外用.extend()函数可以将新列表解开追加在原列表后,区别于attend不解包作为一个整体追加。
self.vars_bn.extend([running_mean, running_var])
Learner类还重写了exter_repr()方法,便于打印信息,具体不讨论。
def extra_repr(self):
重点是forword()方法:
def forward(self, x, vars=None, bn_training=True):
"""
This function can be called by finetunning, however, in finetunning, we dont wish to update
running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights.
Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
:param x: [b, 1, 28, 28]
:param vars:
:param bn_training: set False to not update
:return: x, loss, likelihood, kld
"""
if vars is None:
vars = self.vars
idx = 0
bn_idx = 0
for name, param in self.config:
if name is 'conv2d':
w, b = vars[idx], vars[idx + 1]
# remember to keep synchrozied of forward_encoder and forward_decoder!
x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
idx += 2
# print(name, param, '\tout:', x.shape)
elif name is 'convt2d':
w, b = vars[idx], vars[idx + 1]
# remember to keep synchrozied of forward_encoder and forward_decoder!
x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
idx += 2
# print(name, param, '\tout:', x.shape)
elif name is 'linear':
w, b = vars[idx], vars[idx + 1]
x = F.linear(x, w, b)
idx += 2
# print('forward:', idx, x.norm().item())
elif name is 'bn':
w, b = vars[idx], vars[idx + 1]
running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1]
x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
idx += 2
bn_idx += 2
elif name is 'flatten':
# print(x.shape)
x = x.view(x.size(0), -1)
elif name is 'reshape':
# [b, 8] => [b, 2, 2, 2]
x = x.view(x.size(0), *param)
elif name is 'relu':
x = F.relu(x, inplace=param[0])
elif name is 'leakyrelu':
x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
elif name is 'tanh':
x = F.tanh(x)
elif name is 'sigmoid':
x = torch.sigmoid(x)
elif name is 'upsample':
x = F.upsample_nearest(x, scale_factor=param[0])
elif name is 'max_pool2d':
x = F.max_pool2d(x, param[0], param[1], param[2])
elif name is 'avg_pool2d':
x = F.avg_pool2d(x, param[0], param[1], param[2])
else:
raise NotImplementedError
# make sure variable is used properly
assert idx == len(vars)
assert bn_idx == len(self.vars_bn)
return x
在建立网络的同时,完成了输入x在初始化后的网络中前向传播的过程。
另外类还重写了梯度清零zero_grad()方法,暂时不管。
def zero_grad(self, vars=None):
……
现在回到上一次Learner类实例化的位置meta.py中。
3.4 meta.py
定义了一个将梯度通过L2范数(欧几里得范数)裁剪,防止梯度爆炸的函数。
其中
def clip_grad_by_norm_(self, grad, max_norm):
重点是前向传播函数,与后向传播训练更新参数息息相关,由于目前在Omniglot_tain.py中仅仅实例化maml = Meta(args, config).to(device),前向传播将与maml,目前还未执行到xxx=maml(……),所以继续返回到Omniglot_tain.py函数中,往下执行。
3.5 Omniglot_tain.py
# db_train即训练database 是专门处理数据集的OmniglotNShot类的实例化
db_train = OmniglotNShot('omniglot',
batchsz=args.task_num,
n_way=args.n_way,
k_shot=args.k_spt,
k_query=args.k_qry,
imgsz=args.imgsz)
进入OmniglotNshot.py函数。
3.6 OmniglotNshot.py(从本地数据集中构建加载数据集,好难懂!)
class OmniglotNShot:
def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz):
"""
Different from mnistNShot, the
:param root:
:param batchsz: task num
:param n_way:
:param k_shot:
:param k_qry:
:param imgsz:
"""
self.resize = imgsz
if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
# if root/data.npy does not exist, just download it
self.x = Omniglot(root, download=True,
transform=transforms.Compose([lambda x: Image.open(x).convert('L'),
lambda x: x.resize((imgsz, imgsz)),
lambda x: np.reshape(x, (imgsz, imgsz, 1)),
lambda x: np.transpose(x, [2, 0, 1]),
lambda x: x/255.])
)
temp = dict() # {label1:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label}
# self.x实际上是一个实例,当出现for循环迭代对象时,如果对象没有实现 __iter__ __next__ 迭代器协议,
# Python的解释器就会去寻找__getitem__ 来迭代对象,如果连__getitem__ 都没有定义,
# 这解释器就会报对象不是迭代器的错误。
# 因此此处self.x通过__getitem__()不断迭代返回(img, target) index从0到最大,超过后自动跳出循环
for (img, label) in self.x:
if label in temp.keys():
temp[label].append(img)
else:
temp[label] = [img]
self.x = []
for label, imgs in temp.items(): # labels info deserted , each label contains 20imgs
self.x.append(np.array(imgs))
# as different class may have different number of imgs
self.x = np.array(self.x).astype(np.float64) # [[20 imgs],..., 1623 classes in total]
# each character contains 20 imgs
print('data shape:', self.x.shape) # [1623, 20, 84, 84, 1] # 1623个类别,每个类别20张图
temp = [] # Free memory
# save all dataset into npy file.
np.save(os.path.join(root, 'omniglot.npy'), self.x)
print('write into omniglot.npy.')
else:
# if data.npy exists, just load it.
self.x = np.load(os.path.join(root, 'omniglot.npy'))
print('load from omniglot.npy.')
# [1623, 20, 84, 84, 1]
# TODO: can not shuffle here, we must keep training and test set distinct!
self.x_train, self.x_test = self.x[:1200], self.x[1200:] # 0~1199个类别作为训练集,1200~1622个类别为测试集, 训练集再分为支持集和查询集
# self.normalization()
self.batchsz = batchsz
self.n_cls = self.x.shape[0] # 1623
self.n_way = n_way # n way
self.k_shot = k_shot # k shot
self.k_query = k_query # k query
assert (k_shot + k_query) <=20
# save pointer of current read batch in total cache
self.indexes = {"train": 0, "test": 0}
self.datasets = {"train": self.x_train, "test": self.x_test} # original data cached
print("DB: train", self.x_train.shape, "test", self.x_test.shape)
# self.datasets["train"]就是self.x_train就是self.x的前1200个类别 # [1200, 20, 84, 84, 1]
# self.datasets["test"]就是self.x_test就是self.x的后423的类别 # [423, 20, 84, 84, 1]
# datasets_cache是一个包含着train和test两个名称的字典,分别作为当前epoch的数据集的key键值
self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), # current epoch data cached
"test": self.load_data_cache(self.datasets["test"])}
前半部分中,通过if判断本地是否存在omniglot.npy(是numpy格式的数据)文件来决定是否进行下载。第一次仅下载代码后,本地不存在该文件,接着由于
self.x = Omniglot(root, download=True, transform=transforms.Compose([lambda x: Image.open(x).convert('L'), lambda x: x.resize((imgsz, imgsz)), lambda x: np.reshape(x, (imgsz, imgsz, 1)), lambda x: np.transpose(x, [2, 0, 1]), lambda x: x/255.]) )
进入Omniglot类。
3.7 Omniglot.py(继承data.Dataset,可以通过__getitem__()魔法函数迭代返回一张张数据,主要用于下载数据集和预处理数据)
class Omniglot(data.Dataset):
urls = [
'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
]
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'
'''
The items are (filename,category). The index of all the categories can be found in self.idx_classes
Args:
- root: the directory where the dataset will be stored
- transform: how to transform the input
- target_transform: how to transform the target
- download: need to download the dataset
'''
def __init__(self, root, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
if not self._check_exists():
if download:
self.download()
else:
raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
self.idx_classes = index_classes(self.all_items)
检查路径是否存在,不存在则执行下面下载代码。暂不展示。
第一个重点是下面的两个方法。
self.all_items = find_classes(os.path.join(self.root, self.processed_folder)) self.idx_classes = index_classes(self.all_items)
def find_classes(root_dir):
retour = []
for (root, dirs, files) in os.walk(root_dir):
for f in files:
if (f.endswith("png")):
r = root.split('/')
lr = len(r)
retour.append((f, r[lr - 2] + "/" + r[lr - 1], root))
print("== Found %d items " % len(retour))
return retour
# 为每个类别创建一个索引值
# {'classA': 0, 'classB': 1, 'classC': 2}
def index_classes(items):
idx = {}
for i in items:
if i[1] not in idx:
idx[i[1]] = len(idx)
print("== Found %d classes" % len(idx))
return idx
find_classes()中,返回一个列表,列表中存放着每一张图片的(图片文件名, 图片类别, 图片在文件系统所在的实际目录路径)。
os.walk() 是 Python 中的一个非常方便的函数,用于遍历目录和其子目录,并返回一个三元组(root, dirs, files)的生成器,按照目录树一层一层递进迭代,其中root为当前访问的路径(字符串)、dirs是一个列表,表示当前路径下的子文件夹的名字(不包含路径仅包含名字)、files是一个列表,表示当前路径下的文件名字(不包含路径仅包含名字)。
index_classes()中,i[1]是图片所属的类别。通过此函数,生成一个字典,键key为i[1]即类别名,值value为i[1]首次出现的位次。绕过来绕过去就是为每个类别赋值一个数字来代表。
另外因为继承了data.Dataset类,重写了两个方法:
# 每次仅仅返回一张图片和标签
def __getitem__(self, index):
filename = self.all_items[index][0]
img = str.join('/', [self.all_items[index][2], filename])
target = self.idx_classes[self.all_items[index][1]] # 此时target即为数字索引了
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.all_items)
这样用Omniglot类加载数据的时候,可以自动调用__getitem__()迭代返回(img, target)。
现在介绍完Omniglot.py回到OmniglotNShot.py。
3.8 OmniglotNShot.py
通过3.7的__getitem__()就可以解释3.6中的这段代码:(认真学习继承Dataset自动迭代的方法)
# self.x实际上是一个实例,当出现for循环迭代对象时,如果对象没有实现 __iter__ __next__ 迭代器协议,
# Python的解释器就会去寻找__getitem__ 来迭代对象,如果连__getitem__ 都没有定义,
# 这解释器就会报对象不是迭代器的错误。
# 因此此处self.x通过__getitem__()不断迭代返回(img, target) index从0到最大,超过后自动跳出循环
for (img, label) in self.x:
if label in temp.keys():
temp[label].append(img)
else:
temp[label] = [img]
3.6中后续代码修改了self.x的数据格式(float64的np数组),最终得到一个超大列表,列表的shape为[1623, 20, 84, 84, 1],表示1623个类别,每个类别中20个样本,每个样本1通道,像素84*84。保存为omniglot.npy。
随后定义了
self.x_train, self.x_test用来划分数据集为Training_Data(前1200个类别)和Testing_Data(后423个类别)
self.indexes = {"train": 0, "test": 0} # 表示第几个batch,后边.next()函数中逐个batch返回数据用得到
self.datasets = {"train": self.x_train, "test": self.x_test}
又定义了self.datasets_cache
# self.datasets["train"]就是self.x_train就是self.x的前1200个类别 # [1200, 20, 84, 84, 1] # self.datasets["test"]就是self.x_test就是self.x的后423的类别 # [423, 20, 84, 84, 1] # datasets_cache是一个包含着train和test两个名称的字典,分别作为当前epoch的数据集的key键值 self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]), "test": self.load_data_cache(self.datasets["test"])}
其中的self.load_data_cache()方法:
def load_data_cache(self, data_pack): # data_pack == [[20imgs] [20imgs] [……(1200or423个子列表)]]
# data_pack.shape == [1200 or 423, 20, 84, 84, 1]
# print(data_pack.shape) # (1200, 20, 1, 28, 28) or (423, 20, 1, 28, 28)
"""
Collects several batches data for N-shot learning
:param data_pack: [cls_num, 20, 84, 84, 1]
:return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
"""
# take 5 way 1 shot as example: 5 * 1
setsz = self.k_shot * self.n_way
querysz = self.k_query * self.n_way
data_cache = []
# print('preload next 50 caches of batchsz of batch.')
for sample in range(10): # num of episodes
x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
for i in range(self.batchsz): # one batch means one set
# 若 batchsz == 32(n), 则取32(n)个任务,每个任务有n_way个类别,每个类别k张spt和q张qry图
x_spt, y_spt, x_qry, y_qry = [], [], [], []
selected_cls = np.random.choice(data_pack.shape[0], self.n_way, False) # False表示不放回
#selected_cls = [12, 45, 689, 123, 456, ...(n个索引值)]
for j, cur_class in enumerate(selected_cls):
selected_img = np.random.choice(20, self.k_shot + self.k_query, False)
# 从20个数字中挑出(支持集+查询集)张数的 数字索引 个数, 前k张作为sqt, 剩下的作为qrt
# meta-training(spt) and meta-test(qry)
x_spt.append(data_pack[cur_class][selected_img[:self.k_shot]]) # 通过键值取字典
x_qry.append(data_pack[cur_class][selected_img[self.k_shot:]])
y_spt.append([j for _ in range(self.k_shot)]) # 注意: j 实际上是selected_img列表的索引值
y_qry.append([j for _ in range(self.k_query)])
'''
在这个上下文中,cur_class 是数据集中类别的实际索引,可能对于外部的标签系统来说没有实际意义。
j 是当前类别在 selected_cls 中的索引,它为模型提供了一个更简单的标签系统,
其中每个任务中的每个类别都标记为 0 到 n_way - 1 之间的整数。这简化了标签系统,
并允许模型在每个任务中更容易地识别和区分类别。
'''
# 至此 x_spt, y_spt, x_qry, y_qry 四个列表中包含了一个batch的数据,总共有batchsz个这样的数据构成一个episode,共有10个episode
# 下面将抽取的顺序打乱
# shuffle inside a batch
perm = np.random.permutation(self.n_way * self.k_shot)
x_spt = np.array(x_spt).reshape(self.n_way * self.k_shot, 1, self.resize, self.resize)[perm]
y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
perm = np.random.permutation(self.n_way * self.k_query)
x_qry = np.array(x_qry).reshape(self.n_way * self.k_query, 1, self.resize, self.resize)[perm]
y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
# append 使得x_spt==[sptsz, 1, 84, 84] => x_spts==[b, setsz, 1, 84, 84]
x_spts.append(x_spt)
y_spts.append(y_spt)
x_qrys.append(x_qry)
y_qrys.append(y_qry)
# [b, setsz, 1, 84, 84]
x_spts = np.array(x_spts).astype(np.float32).reshape(self.batchsz, setsz, 1, self.resize, self.resize)
y_spts = np.array(y_spts).astype(int).reshape(self.batchsz, setsz)
# [b, qrysz, 1, 84, 84]
x_qrys = np.array(x_qrys).astype(np.float32).reshape(self.batchsz, querysz, 1, self.resize, self.resize)
y_qrys = np.array(y_qrys).astype(int).reshape(self.batchsz, querysz)
data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
return data_cache # data_cache里装着10个列表,每个列表包含x_spts, y_spts, x_qrys, y_qrys四个子列表,
# 每个子列表包含对应子 子列表,例如x_spts内包含task_num(batchsz)个x_spt
# x_spts==[b, setsz, 1, 84, 84], setsz=n_way * k_shot
第一个for循环了10次表示10个eposides,也相当于10个batch批次。(普通深度学习中,假如有1000条数据,设置的batch_size=10则每个batch内含有100条数据,此元学习中,每个batch(task)内有n_way=5个类别,每个类别有k_shot+k_query=15张图片,所以eposide与batch个数有所区别,可以理解为抽取eposides次,不一定把所有数据抽完,会有重复且有未抽到的)。
第二个for循环了task_num次,表示取task_num个任务,每个任务有n_way个类别。
第三个for循环遍历抽出来的类别,对每个类别再抽k张spt和q张qry图。其中x_spt和x_qry列表分别存放抽到的图片的np数据,y_spt和y_qry列表分别存放抽到的图片的所属类别,值得注意的是,此处的类别并非实际1623个类别之一,而是n_way个类别之一(0 1 2 3 4),因为输出logits就是n_way维的。
如此,返回了shape为[b=10, spt_size, 1, 84, 84]的data_cache。
另外,实际训练时会通过.next()方法返回单批次数据,在这里提前展示一下。
def next(self, mode='train'): # 默认返回训练模式下sqt和qrt
"""
Gets next batch from the dataset with name.
:param mode: The name of the splitting (one of "train", "val", "test")
:return:
"""
# update cache if indexes is larger cached num
if self.indexes[mode] >= len(self.datasets_cache[mode]):
self.indexes[mode] = 0
self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
next_batch = self.datasets_cache[mode][self.indexes[mode]] # self.datasets_cache[mode].shape == [b, setsz, 1, 84, 84]
self.indexes[mode] += 1 # indexes表示第几个batch,其中1batch数据即一个episode个手机的数据,包括task_num个任务,每个任务n_way个类别
return next_batch # 返回一个batch的数据 [x_spts, y_spts, x_qrys, y_qrys]
self.index[mode]表示该mode下第几个batch。[x_spts, y_spts, x_qrys, y_qrys],其中x_spts内包含task_num个任务,每个任务n_way个类别,每个类别setsz个图像。
数据的问题解决了,现在返回到Omniglot_train.py。
3.9 Omniglot_train.py
for step in range(args.epoch):
x_spt, y_spt, x_qry, y_qry = db_train.next()
x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
# set traning=True to update running_mean, running_variance, bn_weights, bn_bias
accs = maml(x_spt, y_spt, x_qry, y_qry)
epoch循环。
每次调用next()返回一个batch的数据,注意并不是将整个data_cache列表均返回。一个data_cache内包含episodes个batch。
通过
accs = maml(x_spt, y_spt, x_qry, y_qry)
进入Meta类的前向传播计算。
3.10 meta.py 前向传播
注意:源代码中添加
y_spt = y_spt.long() y_qry = y_qry.long()
否则计算交叉熵数据类型报错。
def forward(self, x_spt, y_spt, x_qry, y_qry):
"""
:param x_spt: [b, setsz, c_, h, w]
:param y_spt: [b, setsz]
:param x_qry: [b, querysz, c_, h, w]
:param y_qry: [b, querysz]
:return:
"""
y_spt = y_spt.long()
y_qry = y_qry.long()
task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)
losses_q = [0 for _ in range(self.update_step + 1)] # losses_q[i] is the loss on step i
corrects = [0 for _ in range(self.update_step + 1)]
for i in range(task_num):
# 1. run the i-th task and compute loss for k=0
logits = self.net(x_spt[i], vars=None, bn_training=True) # vars用于初始化网络参数, None表示用创建网络时的参数
loss = F.cross_entropy(logits, y_spt[i])
grad = torch.autograd.grad(loss, self.net.parameters())
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
# fast_weights=parameters−learning_rate×gradients
'''
grad = [Δw1, Δw2, ..., Δwn] (模型参数的梯度列表)
self.net.parameters() = [w1, w2, ..., wn] (模型参数列表)
zip(grad, self.net.parameters())==[(Δw1, w1), (Δw2, w2), ..., (Δwn, wn)]
map函数将一个函数(lambda)应用于一个序列p的每个元素。这里,它将一个 lambda 函数应用于 zip 函数生成的元组列表。
转变成list后fast_weights包含了更新后的所有参数列表
fast_weights = [
(w1 - self.update_lr * Δw1),
(w2 - self.update_lr * Δw2),
...
(wn - self.update_lr * Δwn)
]
'''
# this is the loss and accuracy before first update
with torch.no_grad():
# [setsz, nway]
logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
loss_q = F.cross_entropy(logits_q, y_qry[i])
losses_q[0] += loss_q
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
correct = torch.eq(pred_q, y_qry[i]).sum().item()
corrects[0] = corrects[0] + correct
# this is the loss and accuracy after the first update
with torch.no_grad():
# [setsz, nway]
logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
loss_q = F.cross_entropy(logits_q, y_qry[i])
losses_q[1] += loss_q
# [setsz]
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
correct = torch.eq(pred_q, y_qry[i]).sum().item()
corrects[1] = corrects[1] + correct
for k in range(1, self.update_step):
# 1. run the i-th task and compute loss for k=1~K-1
logits = self.net(x_spt[i], fast_weights, bn_training=True)
loss = F.cross_entropy(logits, y_spt[i])
# 2. compute grad on theta_pi
grad = torch.autograd.grad(loss, fast_weights)
# 3. theta_pi = theta_pi - train_lr * grad
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
# loss_q will be overwritten and just keep the loss_q on last update step.
loss_q = F.cross_entropy(logits_q, y_qry[i])
losses_q[k + 1] += loss_q
with torch.no_grad():
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
correct = torch.eq(pred_q, y_qry[i]).sum().item() # convert to numpy
corrects[k + 1] = corrects[k + 1] + correct
# end of all tasks
# sum over all losses on query set across all tasks
loss_q = losses_q[-1] / task_num
# optimize theta parameters
self.meta_optim.zero_grad() # 梯度清零
loss_q.backward()
# print('meta update')
# for p in self.net.parameters()[:5]:
# print(torch.norm(p).item())
self.meta_optim.step()
accs = np.array(corrects) / (querysz * task_num)
return accs
主要看for循环中,首先将x_spts的第i个任务的数据取出来,喂给net,计算了原始var下的logits,与y_spts[i]计算cross_entropy,计算grad,计算fast_weight = theta - update_lr * grad,这里用到了map和lambda函数,其中p=zip(grad, self.net.parameters())。
然后with no_grad():
losses_q[0]+=将x_qry喂入原始参数时,输出结果的loss
corrects[0]+=将x_qry喂入原始参数时,输出结果的正确个数
然后with no_grad():
losses_q[1]+=将x_qry喂入通过x_spt数据更新的第一次参数时,输出结果的loss
corrects[1]+=将x_qry喂入通过x_spt数据更新的第一次参数时,输出结果的正确个数
然后for k in range(1, self.update_step):
losses_q[1+1]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的loss
corrects[1+1]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的正确个数
……
k==self.update_step-1时,
losses_q[update_step]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的loss
corrects[update_step]+=将x_qry喂入通过x_spt数据更新的第二次参数时,输出结果的正确个数
故该两个列表都是update_step+1维,依次表示原参数loss/correct_num、第一次支持集更新后查询集上的loss/correct_num、第二次支持集更新后查询集上的loss/correct_num、……、第update_step次支持集更新后查询集上的loss/correct_num。
遍历完所有任务后,由于记录数据时都是+=,所以最后统一 / task_num,得到平均的损失和正确率。
self.meta_optim.zero_grad() # 梯度清零 loss_q.backward() self.meta_optim.step()
通过这个完成一次loss_q反向传播,并更新参数。
完成task_num次循环后,回到Omniglot_train.py。
3.11 Omniglot_train.py
if step % 50 == 0:
print('step:', step, '\ttraining acc:', accs)
if step % 500 == 0:
accs = []
for _ in range(1000//args.task_num):
# test
x_spt, y_spt, x_qry, y_qry = db_train.next('test')
x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
# split to single task each time
for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
accs.append( test_acc )
# [b, update_step+1]
accs = np.array(accs).mean(axis=0).astype(np.float16)
print('Test acc:', accs)
每50个epoch打印一下acc,主要看第一个和最后一个值,分别代表初始参数下正确率和当前更新后参数的查询集准确率。
每500个epoch,进行一次Testing_Data测试。其中
# split to single task each time for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry): test_acc = maml.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one) accs.append( test_acc )
将数据分离,因此测试任务每次应该只有一个任务,n_way个类别。
回到meta.py的finetunning()方法。
3.12 meta.py(Testing_Data微调)
def finetunning(self, x_spt, y_spt, x_qry, y_qry):
"""
:param x_spt: [setsz, c_, h, w]
:param y_spt: [setsz]
:param x_qry: [querysz, c_, h, w]
:param y_qry: [querysz]
:return:
"""
y_spt = y_spt.long()
y_qry = y_qry.long()
assert len(x_spt.shape) == 4
querysz = x_qry.size(0)
corrects = [0 for _ in range(self.update_step_test + 1)]
# in order to not ruin the state of running_mean/variance and bn_weight/bias
# we finetunning on the copied model instead of self.net
net = deepcopy(self.net)
# 1. run the i-th task and compute loss for k=0
logits = net(x_spt)
loss = F.cross_entropy(logits, y_spt)
grad = torch.autograd.grad(loss, net.parameters())
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))
# this is the loss and accuracy before first update
with torch.no_grad():
# [setsz, nway]
logits_q = net(x_qry, net.parameters(), bn_training=True)
# [setsz]
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
# scalar
correct = torch.eq(pred_q, y_qry).sum().item()
corrects[0] = corrects[0] + correct
# this is the loss and accuracy after the first update
with torch.no_grad():
# [setsz, nway]
logits_q = net(x_qry, fast_weights, bn_training=True)
# [setsz]
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
# scalar
correct = torch.eq(pred_q, y_qry).sum().item()
corrects[1] = corrects[1] + correct
for k in range(1, self.update_step_test):
# 1. run the i-th task and compute loss for k=1~K-1
logits = net(x_spt, fast_weights, bn_training=True)
loss = F.cross_entropy(logits, y_spt)
# 2. compute grad on theta_pi
grad = torch.autograd.grad(loss, fast_weights)
# 3. theta_pi = theta_pi - train_lr * grad
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
logits_q = net(x_qry, fast_weights, bn_training=True)
# loss_q will be overwritten and just keep the loss_q on last update step.
loss_q = F.cross_entropy(logits_q, y_qry)
with torch.no_grad():
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
correct = torch.eq(pred_q, y_qry).sum().item() # convert to numpy
corrects[k + 1] = corrects[k + 1] + correct
del net
accs = np.array(corrects) / querysz
return accs
微调和训练过程十分类似,区别在于微调仅输入单任务,并且初始参数是训练后所认为的最优参数。
由于微调是在原训练模型的基础上,并且要保证原最优参数不变,故需要通过
net = deepcopy(self.net)
复制一份模型。
另外还有就是update_step_test不同,一般大于update_step。因此一般新任务多训练几次效果会更好,只不过在找最优初始参数的训练时为了防止数据量小过拟合、提高训练速度而减少更新次数。
返回值accs列表中最后一个值表示模型参数通过update_step_test次该任务支持集的更新后,在该任务上的查询集上的准确率。
四、总结
经过5个py文件的12次相互调用后,该代码运行并解析完毕!
配置模型(独特的参数初始化方法)——下载数据——加载数据(加载数据较为复杂难懂,关键区分epoch、eposides、task_num=batchsz、n_way几个参数的实际意义)——训练(支持集更新、查询集计算loss和correct)——测试(单任务支持集更新、查询集计算loss和correct)。