基于PyTorch的少样本学习(Few-shot Learning)实现:用"小抄"教会AI快速学习新任务
关键词:少样本学习、PyTorch、元学习、支持集、原型网络
摘要:传统深度学习需要"海量数据喂养",但现实中很多场景(如罕见病诊断、新物种识别)只有少量样本。本文将用"小学生考试"的比喻,带您一步步理解少样本学习(Few-shot Learning)的核心原理,并用PyTorch实现一个能"看5张图学会新类别"的原型网络(Prototypical Networks),最后结合医疗、金融等真实场景,揭示这项技术如何让AI从"数据吃货"变身"学习高手"。
背景介绍
目的和范围
想象一下:医生拿到一种从未见过的罕见病患者的5张CT图像,希望AI能快速识别同类病例;程序员需要让客服机器人识别"用户抱怨快递盒破损"这个新意图,但只有3条标注数据。这些场景中,传统深度学习(需要成千上万张标注图)完全失效,少样本学习(Few-shot Learning, FSL)正是为解决这类问题而生。本文将覆盖少样本学习的核心概念、经典算法(原型网络)、PyTorch实战代码,以及真实应用场景。
预期读者
- 对PyTorch有基础了解(会写简单的神经网络)
- 听说过深度学习但没接触过少样本学习的开发者/学生
- 想解决"样本不足"实际问题的算法工程师
文档结构概述
本文将按照"概念理解→原理拆解→代码实战→场景落地"的逻辑展开:先用"考试小抄"的故事解释少样本学习;再拆解核心概念(支持集、元学习等)并画流程图;接着用PyTorch实现原型网络,逐行解读代码;最后结合医疗/金融场景说明应用价值,推荐学习资源。
术语表
术语 | 解释(用小学生能懂的话) |
---|---|
支持集(Support Set) | 像考试前老师给的"小抄":AI学习新任务时,用来"参考"的少量样本(比如5张猫的照片) |
查询集(Query Set) | 像考试题目:AI学完"小抄"后需要识别的新样本(比如1张没见过的猫的照片) |
元学习(Meta-Learning) | "学习如何学习"的超能力:AI不是死记硬背具体知识,而是学会"如何快速掌握新知识"的方法(就像学生学会高效记笔记的方法,新科目也能快速上手) |
原型(Prototype) | 同类样本的"平均特征":比如把5张猫的照片的特征"揉成一团",得到一个代表猫的"典型特征",新照片和它越像就是猫 |
核心概念与联系
故事引入:小明的"学习超能力"
小明是个转学生,每天要转去新学校上不同的课(比如周一学火星语,周二学恐龙分类)。但每门新课他只能看到5页课本(支持集),然后就要考试(用查询集测试)。普通学生只能死记硬背这5页,遇到没见过的题目就抓瞎。但小明学会了"学习方法"(元学习):他发现所有课本的知识都能提炼成"关键词特征"(比如火星语的"尖刺发音"、恐龙的"尾巴长度"),然后把5页课本的关键词揉成"典型特征"(原型),考试时只要新题目和哪个"典型特征"像,就选哪个答案。少样本学习中的AI,就像小明这样的"学习高手"。
核心概念解释(像给小学生讲故事)
核心概念一:支持集 vs 查询集——小抄与考题
- 支持集(Support Set):AI学习新任务时的"小抄"。比如要识别新动物"霍加狓",支持集可能是5张霍加狓的照片+5张斑马的照片(假设是2类5样本的"2-way 5-shot"任务)。
- 查询集(Query Set):AI学完"小抄"后要测试的"考题"。比如10张混合了霍加狓和斑马的照片,AI需要正确分类这些"考题"。
核心概念二:元学习——学会"学习方法"
元学习(Meta-Learning)的核心是"学习如何学习"。传统深度学习像"填鸭式学生",只能记住具体知识(比如只认识常见的猫);元学习像"学习方法大师",它从大量"小任务"(比如学过的1000种动物分类任务)中总结出"如何用少量样本快速学习新任务"的方法。就像学生通过做大量试卷,总结出"先划重点再记忆"的学习方法,遇到新科目也能快速上手。
核心概念三:原型网络——找"典型特征"
原型网络(Prototypical Networks)是少样本学习的经典算法。它的思路很简单:给每个类别计算一个"典型特征"(原型),新样本和哪个原型最像,就属于哪个类别。比如支持集中有5张猫的照片,先把每张照片用神经网络转换成特征向量(像给照片做"特征指纹"),然后把这5个指纹取平均,得到猫的"原型指纹"。新照片的指纹和猫的原型指纹距离近,就判断为猫。
核心概念之间的关系(用小学生能理解的比喻)
- 支持集+查询集 = 练习册的"例题+习题":支持集是例题(小抄),查询集是习题(考题),AI通过例题学会方法,再用习题测试是否学会。
- 元学习+原型网络 = 学习方法+具体技巧:元学习教会AI"如何用少量样本学习"(比如总结特征的方法),原型网络是其中一种具体技巧(计算原型特征)。
- 支持集+原型 = 小抄+重点提炼:支持集是原始小抄,原型是从小抄中提炼的"重点"(平均特征),AI通过重点快速匹配新样本。
核心概念原理和架构的文本示意图
少样本学习的核心流程:
- 元训练阶段:用大量"历史任务"(每个任务包含支持集+查询集)训练模型,让模型学会"如何用支持集快速生成原型,并匹配查询集"。
- 元测试阶段:遇到新任务时,用新任务的支持集生成原型,用原型对查询集分类。
Mermaid 流程图
graph TD
A[元训练阶段] --> B[历史任务1: 支持集S1+查询集Q1]
A --> C[历史任务2: 支持集S2+查询集Q2]
A --> D[...多个历史任务...]
B --> E[提取S1中每个样本的特征]
C --> E
D --> E
E --> F[计算每个类别的原型(特征平均)]
F --> G[用原型对Q1/Q2/...分类,计算损失并更新模型]
G --> H[模型学会"如何生成原型并匹配查询集"]
H --> I[元测试阶段]
I --> J[新任务: 支持集S_new+查询集Q_new]
J --> K[提取S_new特征,生成新原型]
K --> L[用新原型对Q_new分类]
核心算法原理 & 具体操作步骤(以原型网络为例)
原型网络的核心是"特征提取→计算原型→距离匹配"三步骤,我们用PyTorch实现这一过程。
步骤1:特征提取(用神经网络做"特征翻译机")
需要一个神经网络(称为"嵌入网络",Encoder),把图像(或其他数据)转换成固定长度的特征向量。就像把中文翻译成英文,嵌入网络把"像素矩阵"翻译成"特征语言",让计算机能理解图像的"内在含义"。
步骤2:计算原型(给每个类别做"特征平均值")
假设支持集有N个类别(N-way),每个类别有K个样本(K-shot),总共有N×K个样本。对每个类别,取该类别K个样本的特征向量的平均值,得到该类的原型(Prototype)。数学上表示为:
p
c
=
1
K
∑
i
=
1
K
f
(
x
c
,
i
)
p_c = \frac{1}{K} \sum_{i=1}^K f(x_{c,i})
pc=K1i=1∑Kf(xc,i)
其中,
f
(
x
)
f(x)
f(x)是嵌入网络的输出(特征向量),
p
c
p_c
pc是类别c的原型。
步骤3:距离匹配(用"相似度"判断类别)
对于查询集中的每个样本
x
q
x_q
xq,计算它的特征
f
(
x
q
)
f(x_q)
f(xq)与所有类别原型
p
c
p_c
pc的距离(常用欧氏距离或余弦相似度)。距离最近的类别即为预测结果。欧氏距离公式:
d
(
f
(
x
q
)
,
p
c
)
=
∑
i
=
1
D
(
f
(
x
q
)
i
−
p
c
i
)
2
d(f(x_q), p_c) = \sqrt{\sum_{i=1}^D (f(x_q)_i - p_c^i)^2}
d(f(xq),pc)=i=1∑D(f(xq)i−pci)2
其中D是特征向量的维度(比如64维)。
PyTorch代码框架(核心部分)
import torch
import torch.nn as nn
import torch.nn.functional as F
# 步骤1:定义嵌入网络(简单CNN为例)
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 64, 3), # 输入1通道(灰度图),输出64通道
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2), # 下采样,缩小图像尺寸
# 重复3次卷积块,最终输出64维特征
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.fc = nn.Linear(64, 64) # 全连接层输出64维特征
def forward(self, x):
x = self.conv(x) # 卷积后得到 (batch, 64, 1, 1)
x = x.view(x.size(0), -1) # 展平为 (batch, 64)
x = self.fc(x) # 输出64维特征向量
return x
# 步骤2:计算原型(在训练/测试时动态计算)
def compute_prototypes(support_features, support_labels, num_classes):
# support_features: (N*K, 64),N是类别数,K是样本数
# support_labels: (N*K,),类别标签(0到N-1)
prototypes = []
for c in range(num_classes):
# 选出属于类别c的所有样本的特征
mask = (support_labels == c)
class_features = support_features[mask]
# 计算平均特征(原型)
prototype = class_features.mean(dim=0)
prototypes.append(prototype)
return torch.stack(prototypes) # (N, 64)
# 步骤3:计算查询样本与原型的距离(欧氏距离)
def euclidean_distance(query_features, prototypes):
# query_features: (Q, 64),Q是查询样本数
# prototypes: (N, 64)
# 计算每个查询样本与所有原型的距离
# 公式:(a - b)^2 = a² + b² - 2ab
a_sq = query_features.pow(2).sum(dim=1, keepdim=True) # (Q, 1)
b_sq = prototypes.pow(2).sum(dim=1) # (N,)
ab = query_features @ prototypes.T # (Q, N)
distances = a_sq + b_sq - 2 * ab # (Q, N)
return distances
数学模型和公式 & 详细讲解 & 举例说明
损失函数:交叉熵损失(用"猜答案的准确度"作为学习目标)
原型网络的训练目标是:让查询样本的预测类别尽可能接近真实标签。预测概率通过"负距离"的softmax计算(距离越近,概率越高):
p
(
c
∣
x
q
)
=
exp
(
−
d
(
f
(
x
q
)
,
p
c
)
)
∑
c
′
=
1
N
exp
(
−
d
(
f
(
x
q
)
,
p
c
′
)
)
p(c|x_q) = \frac{\exp(-d(f(x_q), p_c))}{\sum_{c'=1}^N \exp(-d(f(x_q), p_{c'}))}
p(c∣xq)=∑c′=1Nexp(−d(f(xq),pc′))exp(−d(f(xq),pc))
损失函数使用交叉熵:
L
=
−
1
Q
∑
q
=
1
Q
log
(
p
(
y
q
∣
x
q
)
)
\mathcal{L} = -\frac{1}{Q} \sum_{q=1}^Q \log(p(y_q|x_q))
L=−Q1q=1∑Qlog(p(yq∣xq))
其中
y
q
y_q
yq是查询样本
x
q
x_q
xq的真实标签,
Q
Q
Q是查询集样本数。
举例说明:2-way 1-shot任务(识别猫和狗,各1张支持图)
- 支持集:猫的照片 x c a t x_{cat} xcat,狗的照片 x d o g x_{dog} xdog。
- 嵌入网络输出: f ( x c a t ) = [ 0.2 , 0.5 , 0.3 ] f(x_{cat}) = [0.2, 0.5, 0.3] f(xcat)=[0.2,0.5,0.3](假设3维特征), f ( x d o g ) = [ 0.8 , 0.1 , 0.4 ] f(x_{dog}) = [0.8, 0.1, 0.4] f(xdog)=[0.8,0.1,0.4]。
- 计算原型: p c a t = f ( x c a t ) = [ 0.2 , 0.5 , 0.3 ] p_{cat} = f(x_{cat}) = [0.2, 0.5, 0.3] pcat=f(xcat)=[0.2,0.5,0.3](因为1-shot,平均就是自己), p d o g = [ 0.8 , 0.1 , 0.4 ] p_{dog} = [0.8, 0.1, 0.4] pdog=[0.8,0.1,0.4]。
- 查询样本:一张新照片 x q x_q xq,特征 f ( x q ) = [ 0.3 , 0.4 , 0.3 ] f(x_q) = [0.3, 0.4, 0.3] f(xq)=[0.3,0.4,0.3]。
- 计算欧氏距离:
- d ( x q , p c a t ) = ( 0.3 − 0.2 ) 2 + ( 0.4 − 0.5 ) 2 + ( 0.3 − 0.3 ) 2 = 0.01 + 0.01 + 0 = 0.141 d(x_q, p_{cat}) = \sqrt{(0.3-0.2)^2 + (0.4-0.5)^2 + (0.3-0.3)^2} = \sqrt{0.01 + 0.01 + 0} = 0.141 d(xq,pcat)=(0.3−0.2)2+(0.4−0.5)2+(0.3−0.3)2=0.01+0.01+0=0.141
- d ( x q , p d o g ) = ( 0.3 − 0.8 ) 2 + ( 0.4 − 0.1 ) 2 + ( 0.3 − 0.4 ) 2 = 0.25 + 0.09 + 0.01 = 0.591 d(x_q, p_{dog}) = \sqrt{(0.3-0.8)^2 + (0.4-0.1)^2 + (0.3-0.4)^2} = \sqrt{0.25 + 0.09 + 0.01} = 0.591 d(xq,pdog)=(0.3−0.8)2+(0.4−0.1)2+(0.3−0.4)2=0.25+0.09+0.01=0.591
- 预测: x q x_q xq与猫的原型距离更近,所以预测为猫。
项目实战:代码实际案例和详细解释说明
开发环境搭建
- 系统:Windows/Linux/macOS(推荐Ubuntu)
- 语言:Python 3.8+
- 框架:PyTorch 1.9+(安装命令:
pip install torch torchvision
) - 数据集:Omniglot(少样本学习经典数据集,包含1623个字符类别,每个类别20张手写图片,适合小样本场景)
源代码详细实现和代码解读
我们将实现一个完整的原型网络训练流程,包含数据加载、模型定义、训练循环。
步骤1:数据加载(用Task的方式组织支持集和查询集)
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class OmniglotDataset(Dataset):
def __init__(self, root, split='train', n_way=5, k_shot=1, q_query=15):
self.root = os.path.join(root, split)
self.n_way = n_way # 每个任务的类别数(N-way)
self.k_shot = k_shot # 每个类别的支持样本数(K-shot)
self.q_query = q_query # 每个类别的查询样本数
self.transform = transforms.Compose([
transforms.Resize((28, 28)), # 缩放到28x28
transforms.ToTensor(), # 转成Tensor(0-1)
transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到-1~1
])
# 收集所有字符路径(每个字符是一个类别)
self.classes = [os.path.join(self.root, family, char)
for family in os.listdir(self.root)
for char in os.listdir(os.path.join(self.root, family))]
self.classes = sorted(self.classes)
def __getitem__(self, index):
# 随机选择n_way个类别
selected_classes = np.random.choice(self.classes, self.n_way, replace=False)
support_images = []
support_labels = []
query_images = []
query_labels = []
for label, cls in enumerate(selected_classes):
# 每个类别选k_shot + q_query张图片
all_images = [os.path.join(cls, img) for img in os.listdir(cls)]
selected_images = np.random.choice(all_images, self.k_shot + self.q_query, replace=False)
# 前k_shot张作为支持集
for img_path in selected_images[:self.k_shot]:
img = Image.open(img_path).convert('L') # 转为灰度图
support_images.append(self.transform(img))
support_labels.append(label)
# 后q_query张作为查询集
for img_path in selected_images[self.k_shot:]:
img = Image.open(img_path).convert('L')
query_images.append(self.transform(img))
query_labels.append(label)
# 合并成Tensor并打乱顺序(PyTorch需要)
support_images = torch.stack(support_images) # (n_way*k_shot, 1, 28, 28)
support_labels = torch.tensor(support_labels) # (n_way*k_shot,)
query_images = torch.stack(query_images) # (n_way*q_query, 1, 28, 28)
query_labels = torch.tensor(query_labels) # (n_way*q_query,)
return support_images, support_labels, query_images, query_labels
def __len__(self):
return 1000 # 自定义epoch数(实际可根据需求调整)
步骤2:模型训练循环
def train():
# 参数设置
n_way = 5
k_shot = 1
q_query = 15
lr = 0.001
epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 初始化模型、数据加载器、优化器
encoder = Encoder().to(device)
dataset = OmniglotDataset(root='omniglot', split='train',
n_way=n_way, k_shot=k_shot, q_query=q_query)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True) # 每个batch是一个任务
optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
for epoch in range(epochs):
encoder.train()
total_loss = 0
total_acc = 0
for batch in dataloader:
# 加载支持集和查询集(batch_size=1,所以取[0])
support_imgs, support_labels, query_imgs, query_labels = [x[0].to(device) for x in batch]
# 步骤1:提取支持集和查询集的特征
support_features = encoder(support_imgs) # (n_way*k_shot, 64)
query_features = encoder(query_imgs) # (n_way*q_query, 64)
# 步骤2:计算原型(n_way个类别,每个类别k_shot样本)
prototypes = compute_prototypes(support_features, support_labels, n_way) # (n_way, 64)
# 步骤3:计算查询样本与原型的距离(欧氏距离)
distances = euclidean_distance(query_features, prototypes) # (n_way*q_query, n_way)
# 步骤4:计算预测概率(负距离的softmax)
logits = -distances # 距离越近,logits越大(因为要最大化概率)
probs = F.softmax(logits, dim=1)
# 步骤5:计算交叉熵损失
loss = F.cross_entropy(logits, query_labels)
# 步骤6:反向传播更新模型
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 计算准确率
preds = probs.argmax(dim=1)
acc = (preds == query_labels).float().mean()
total_loss += loss.item()
total_acc += acc.item()
# 打印每轮训练结果
avg_loss = total_loss / len(dataloader)
avg_acc = total_acc / len(dataloader)
print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}')
if __name__ == '__main__':
train()
代码解读与分析
- 数据加载器:
OmniglotDataset
将数据组织成"任务"(每个任务包含支持集和查询集),模拟少样本场景。例如,n_way=5
表示每个任务要学5个新类别,k_shot=1
表示每个类别只有1张支持图。 - 特征提取:
Encoder
是一个简单的CNN,将28x28的灰度图转换为64维特征向量。卷积层的作用是提取图像的局部特征(如线条、弧度),全连接层将其整合为全局特征。 - 原型计算:
compute_prototypes
函数通过平均支持集特征生成原型,这是少样本学习的核心——用少量样本总结出"典型特征"。 - 训练循环:每个batch是一个独立的任务,模型通过大量任务学习"如何生成有效原型"。损失函数引导模型让查询样本的预测概率尽可能接近真实标签。
实际应用场景
场景1:医疗影像中的罕见病诊断
医院数据库中可能只有5-10张某罕见病的CT图像(如"肺淋巴管肌瘤病"),传统深度学习无法训练。少样本学习可以用这些少量样本生成该病的"特征原型",快速识别新病例。
场景2:金融风控中的新欺诈模式识别
新的欺诈手段(如"虚假海外购物退款")刚出现时,只有少量标注案例。少样本学习能基于这些案例总结欺诈特征,实时检测新的欺诈行为。
场景3:智能家居的新语音指令学习
用户可能自定义新的语音指令(如"打开儿童房氛围灯"),但只有3-5次历史录音。少样本学习可以快速学习这些新指令,无需重新训练大模型。
工具和资源推荐
类型 | 推荐资源 | 说明 |
---|---|---|
数据集 | Omniglot(PyTorch可直接加载) | 少样本学习经典手写字符数据集 |
Mini-ImageNet(需要手动下载) | 更接近真实场景的小样本图像数据集 | |
开源库 | Torchmeta(https://github.com/tristandeleu/pytorch-meta) | PyTorch的元学习专用库,包含数据加载器和经典模型(如原型网络) |
论文 | 《Prototypical Networks for Few-shot Learning》(2017) | 原型网络的原始论文,详细推导数学原理 |
《Matching Networks for One Shot Learning》(2016) | 少样本学习早期经典论文,提出匹配网络 | |
教程 | PyTorch官方少样本学习教程(https://pytorch.org/tutorials/intermediate/meta_learning_tutorial.html) | 官方实战教程,包含MAML(另一种元学习算法)的实现 |
未来发展趋势与挑战
趋势1:与大语言模型(LLM)结合
GPT-3等大模型展示了强大的"上下文学习"(In-Context Learning)能力,本质上是少样本学习。未来少样本学习可能与LLM深度融合,通过"提示(Prompt)"让模型仅用少量示例完成复杂任务(如代码生成、文本分类)。
趋势2:跨模态少样本学习
当前少样本学习多集中在单模态(如图像),未来可能扩展到跨模态(如图像+文本)。例如,用少量"图像+描述"对,让模型学会根据文本描述生成或识别图像。
挑战1:小样本下的过拟合
支持集样本极少时,模型容易记住个别样本的噪声(如照片中的光线干扰),导致原型不准确。需要更鲁棒的特征提取方法(如引入先验知识或数据增强)。
挑战2:任务泛化能力
元训练阶段的历史任务可能与实际应用任务差异大(如用手写字符训练的模型,用于医学影像任务效果差)。需要研究"跨领域少样本学习",提升模型的泛化性。
总结:学到了什么?
核心概念回顾
- 支持集:AI学习新任务的"小抄"(少量样本)。
- 元学习:AI学会"如何快速学习"的超能力(不是死记硬背,而是掌握学习方法)。
- 原型网络:通过计算"典型特征"(原型),让AI用少量样本总结规律,匹配新样本。
概念关系回顾
支持集是"学习材料",元学习是"学习方法",原型网络是"具体技巧"。三者协作,让AI从"数据吃货"变身"学习高手",在样本极少的场景下也能高效工作。
思考题:动动小脑筋
- 如果支持集的样本质量很差(比如模糊的照片),原型网络的效果会如何?你能想到哪些方法提升这种情况下的性能?
- 除了欧氏距离,还可以用哪些距离度量(如余弦相似度)?它们的优缺点是什么?
- 假设你要开发一个"宠物医生助手",用少样本学习识别罕见宠物疾病(每个疾病只有3张病例图),你会如何设计支持集和查询集?
附录:常见问题与解答
Q1:少样本学习和零样本学习(Zero-shot Learning)有什么区别?
A:少样本学习需要少量样本(如5张),零样本学习不需要任何新类别样本(通过先验知识,如图像的文本描述)。例如,识别"霍加狓"时,少样本学习需要5张霍加狓的图,零样本学习需要知道"霍加狓是条纹像斑马、体型像长颈鹿的动物"的描述。
Q2:支持集的样本数(k-shot)越多越好吗?
A:不一定。当k增大到一定程度(如20),少样本学习和传统监督学习效果接近,但少样本学习的优势是"用更少样本达到相近效果"。实际中,k通常取1-5,因为更大的k可能增加标注成本。
Q3:原型网络和匹配网络(Matching Networks)有什么区别?
A:原型网络用"平均特征"作为原型,匹配网络用"注意力机制"让查询样本直接与支持集每个样本匹配(类似"记住所有小抄细节")。原型网络计算更简单,匹配网络理论上能利用支持集的更多细节,但容易过拟合。
扩展阅读 & 参考资料
- Snell J, Swersky K, Zemel R S. Prototypical networks for few-shot learning[C]//Advances in neural information processing systems. 2017: 4077-4087.(原型网络原始论文)
- 李宏毅元学习课程(https://speech.ee.ntu.edu.tw/~tlkagk/courses_ML20.html):用通俗易懂的方式讲解元学习原理。
- 《少样本学习:理论与实践》(机械工业出版社):系统介绍少样本学习的算法和应用。