技术原理(数学公式)
核心思想:Model-Agnostic Meta-Learning (MAML) 通过优化模型参数的初始化点,使得模型在少量新任务数据上经过少量梯度更新即可快速适应。
目标函数:
min
θ
∑
T
i
∼
p
(
T
)
L
T
i
(
f
θ
−
α
∇
θ
L
T
i
(
f
θ
)
)
\min_{\theta} \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}\left(f_{\theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)}\right)
θminTi∼p(T)∑LTi(fθ−α∇θLTi(fθ))
参数更新规则(两次梯度更新):
- 内循环(Task-specific adaptation):
θ ′ = θ − α ∇ θ L T i ( f θ ) \theta' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta) θ′=θ−α∇θLTi(fθ) - 外循环(Meta-update):
θ ← θ − β ∇ θ ∑ T i L T i ( f θ ′ ) \theta \gets \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i} \mathcal{L}_{\mathcal{T}_i}(f_{\theta'}) θ←θ−β∇θTi∑LTi(fθ′)
实现方法(PyTorch代码)
import torch
import torch.nn as nn
class MAML(nn.Module):
def __init__(self, input_dim, hidden_size=64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 1)
)
def forward(self, x, params=None):
if params is None:
return self.net(x)
return self._apply_weights(x, params)
def _apply_weights(self, x, params):
x = F.relu(torch.matmul(x, params['0.weight']) + params['0.bias'])
return torch.matmul(x, params['1.weight']) + params['1.bias']
def maml_train(model, tasks, inner_lr=0.1, meta_lr=0.001):
optimizer = torch.optim.Adam(model.parameters(), lr=meta_lr)
for _ in range(meta_epochs):
for task_batch in tasks:
loss_sum = 0
for task in task_batch:
# 内循环适配
fast_weights = dict(model.named_parameters())
for _ in range(num_updates):
loss = compute_loss(task, model, fast_weights)
grad = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = {name: param - inner_lr * g for (name, param), g in zip(fast_weights.items(), grad)}
loss_sum += compute_loss(task, model, fast_weights)
# 外循环更新
loss_sum.backward()
optimizer.step()
应用案例(小样本图像分类)
- 场景:5-way 1-shot图像分类(如Omniglot数据集)。
- 效果:在5-way 1-shot任务中,MAML在5次梯度更新后达到92.3%的准确率,显著优于传统迁移学习(如预训练+微调)。
优化技巧
-
超参数调优
- 内循环步数:通常1-5步,步长( α \alpha α)设为0.01~0.1。
- 外循环学习率( β \beta β)设为0.001~0.01。
- 二阶导数近似:Hessian-Free优化,减少内存开销。
-
工程实践
# 梯度累积(防止内存溢出) torch.cuda.empty_cache() # 混合精度训练 scaler = torch.cuda.amp.GradScaler()
前沿进展
- ANIL (Almost No Inner Loop)
简化内循环,仅更新最后一层,减少计算量(ICLR 2020)。for name, param in model.named_parameters(): if 'last_layer' not in name: param.requires_grad = False
- Meta-Dataset
跨领域元学习基准,涵盖ImageNet+TrafficSign等10个领域。
优化技巧(代码示例)
# 二阶导数近似(Hessian-Free)
with torch.backends.cudnn.flags(enabled=False):
grad = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
工程技巧:使用@torch.enable_grad()
强制跟踪二阶导。
效果指标
方法 | 5-way 1-shot (Omniglot) | 收敛步数 |
---|---|---|
MAML | 98.2% | 1000 |
Reptile | 95.7% | 500 |
ProtoNet | 92.3% | 300 |
前沿进展
- Meta-Dataset
多领域元学习基准(Google Research,2023),支持跨域任务泛化。 - Meta-Learned Loss
(ICML 2023)通过元学习损失函数,提升少样本学习泛化性。
# 开源实现(参考)
!git clone https://github.com/cbfinn/maml
总结
MAML通过优化初始化参数,实现快速任务适应,在少样本学习领域具有里程碑意义。未来方向包括:动态架构搜索、任务无关初始化(如ProtoMAML)、跨模态适应等。
扩展阅读:
通过优化初始化策略,MAML及其变体(如ProtoMAML)正推动少样本学习在医疗影像、机器人控制等场景落地。