Source: Zhihu user tik boa
Link: https://zhuanlan.zhihu.com/p/444310267
01
Overview
If you are still grinding away in the CV world, you have certainly heard of Kaiming He's MAE. Before it, self-supervised learning was dominated by contrastive methods such as the MoCo series, SimCLR, BYOL, SimSiam and SimCSE. More recent work has shifted to reconstruction-based self-supervision: BEiT uses token reconstruction as the proxy task, while MAE shows that reconstructing raw pixels works perfectly well, even at a 75% mask ratio.
Moreover, MAE-style self-supervised pretraining also performs remarkably well on downstream tasks. As the table below shows, on the COCO dataset a ViT pretrained with MAE improves box/mask AP by a large margin over supervised pretraining, MoCo, BEiT, and the rest.
COCO object detection and instance segmentation
Before this, when training a downstream network on a custom dataset, one would almost always start from an ImageNet-pretrained model, which has two big problems:
if the custom dataset is similar to ImageNet this is fine, but medical or industrial data live in a different domain, so there is a domain gap;
labeled data for a custom dataset are slow and expensive to obtain, while the large amount of unlabeled data goes unused.
This article therefore walks through implementing MAE by hand, understanding the details, and knowing not just what it does but why. On top of that, we migrate it into the mm-series frameworks (mmdet/mmcls/mmseg) to quickly build a strong backbone of your own. Enough talk: let's step into the world of MAE.
02
MAE (Masked Autoencoders)
The MAE paper is remarkably concise and clear; there is not a single equation in the whole paper. Even if you are new to CV, you will roughly understand it after one read. The MAE encoder-decoder structure is shown below: the encoder takes only the unmasked patches, while the decoder takes the encoder output together with the embeddings of the masked patches.
MAE
Breaking the paper down, there are five main parts:
encoder construction;
decoder construction;
MAE assembly;
pretraining pipeline;
finetuning pipeline;
The following sections unpack each of these five parts in turn.
encoder
The encoder uses the ViT architecture with the classification head removed; what remains consists of the following parts:
Patch embedding
Transformer encoder
Attention
MLP
patch embedding
This part cuts an image into patches, projects them down to the embedding dimension, and turns them into a sequence, as shown below.
patch embedding
The code for this part:
# imports used by the snippets in this article (as in the reference implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import to_2tuple, trunc_normal_

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        # a stride-16 conv both cuts the image into patches and projects each patch to embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, num_patches, embed_dim]
        return x
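As a quick sanity check (a minimal sketch assuming the class above is in scope), a 224x224 image cut into 16x16 patches becomes a sequence of 196 tokens of width 768:

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
imgs = torch.randn(2, 3, 224, 224)     # a batch of 2 RGB images
tokens = patch_embed(imgs)             # conv with stride 16, then flatten + transpose
print(tokens.shape)                    # torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2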
attention
The attention mechanism is essentially an addressing process: given a task-dependent query vector Q, we compute an attention distribution over the keys K and apply it to the values V, which gives the attention output.
The attention computation can roughly be divided into three steps:
① Information input: feed the inputs into the model.
Let $X = [x_1, \dots, x_N]$ denote the sequence of input vectors from which $Q$, $K$ and $V$ are computed.
② Compute the attention distribution $\alpha$: measure the relevance of $Q$ to each key in $K$ with a dot product and turn the scores into probabilities with a softmax:
$\alpha_i = \mathrm{softmax}\big(s(k_i, q)\big) = \dfrac{\exp\big(s(k_i, q)\big)}{\sum_{j}\exp\big(s(k_j, q)\big)}$
We call $\alpha_i$ the attention probability distribution and $s(x_i, q)$ the attention scoring function; common choices are:
additive model: $s(x_i, q) = v^{\top}\tanh(W x_i + U q)$
dot-product model: $s(x_i, q) = x_i^{\top} q$
scaled dot-product model: $s(x_i, q) = \dfrac{x_i^{\top} q}{\sqrt{d}}$
bilinear model: $s(x_i, q) = x_i^{\top} W q$
③ Weighted averaging of information: the attention distribution $\alpha_i$ describes how strongly the $i$-th input is attended to given the query $q$, and the output is
$\mathrm{att}(q, X) = \sum_{i=1}^{N} \alpha_i x_i$
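To make the three steps concrete, here is a minimal single-head scaled dot-product attention sketch in PyTorch (q, k, v are just illustrative random tensors, not part of the MAE code; torch is assumed imported as in the earlier snippet):

B, N, d = 2, 196, 64                            # batch, sequence length, head dim
q, k, v = (torch.randn(B, N, d) for _ in range(3))
scores = q @ k.transpose(-2, -1) / d ** 0.5     # step ②: scaled dot-product scores, [B, N, N]
alpha = scores.softmax(dim=-1)                  # step ②: attention distribution
out = alpha @ v                                 # step ③: weighted average of the values, [B, N, d]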
Attention structure
Multi-head attention simply runs the scaled dot-product attention above H times (one head per run) and merges the outputs.
The multi-head attention formula is:
$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_H)\,W^{O}$, where $\mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q}, K W_i^{K}, V W_i^{V})$
The code is as follows:
class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            # separate q/v biases; the key bias stays fixed at zero
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))   # scaled dot-product scores, [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
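A quick shape check of the Attention module above (a sketch; torch and F are assumed imported as in the earlier snippets):

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
x = torch.randn(2, 196, 768)    # [B, N, C]
y = attn(x)
print(y.shape)                  # torch.Size([2, 196, 768]); attention preserves the token shape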
MLP
The MLP is simple: just two linear layers stacked together, nothing much to say. The code:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commented out here, following the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x
Assembling the encoder
The positional embedding pos_embed can take one of two forms:
① learnable: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
② fixed sinusoidal: self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
The sinusoidal table is generated as follows:
def get_sinusoid_encoding_table(n_position, d_hid):
    ''' Sinusoid position encoding table '''
    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
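A small check of the generated table (assuming torch/numpy are imported as above): for a ViT-Base-sized encoder it has shape [1, 196, 768] and, unlike the learnable variant, carries no gradient:

pos_embed = get_sinusoid_encoding_table(n_position=196, d_hid=768)
print(pos_embed.shape)           # torch.Size([1, 196, 768])
print(pos_embed.requires_grad)   # False; fixed, in contrast to the nn.Parameter version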
Finally, the MAE encoder code is:
class PretrainVisionTransformerEncoder(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
                 use_learnable_pos_emb=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        else:
            # sine-cosine positional embeddings
            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)
        # trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x, mask):
        x = self.patch_embed(x)
        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
        B, _, C = x.shape
        x_vis = x[~mask].reshape(B, -1, C)  # ~mask means visible
        for blk in self.blocks:
            x_vis = blk(x_vis)
        x_vis = self.norm(x_vis)
        return x_vis

    def forward(self, x, mask):
        x = self.forward_features(x, mask)
        x = self.head(x)
        return x
NOTE:
The indexing x_vis = x[~mask].reshape(B, -1, C) is what restricts the encoder to operating only on the unmasked (visible) patches.
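A toy illustration of that indexing trick (made-up sizes): with a boolean mask of shape [B, N], x[~mask] keeps only the visible tokens and the reshape restores the batch dimension; this only works because every sample masks the same number of patches:

B, N, C = 2, 6, 4
x = torch.arange(B * N * C, dtype=torch.float32).reshape(B, N, C)
mask = torch.tensor([[True, True, False, True, True, False],
                     [True, False, True, True, False, True]])   # True = masked; 4 masked per sample
x_vis = x[~mask].reshape(B, -1, C)
print(x_vis.shape)   # torch.Size([2, 2, 4]); only the 2 visible patches per image enter the encoder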
decoder
The decoder's structure is similar to the encoder's. The points worth noting:
its input is the embeddings of all patches (visible tokens plus mask tokens);
its output is the decoded prediction for the masked patches only;
both points are covered in more detail in the MAE section below.
class PretrainVisionTransformerDecoder(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_patches=196,
                 ):
        super().__init__()
        self.num_classes = num_classes
        assert num_classes == 3 * patch_size ** 2   # each token predicts the RGB pixels of one patch
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_size = patch_size

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, return_token_num):
        for blk in self.blocks:
            x = blk(x)
        x = self.head(self.norm(x[:, -return_token_num:]))  # only return the mask tokens' predicted pixels
        return x
MAE
When combining the MAE encoder with the decoder, note the following:
the encoder (ViT-Base) works at an embedding dim of 768 while the decoder is narrower (384 in the article's setup; the PretrainVisionTransformer default below is 512), so encoder_to_decoder linearly projects the visible tokens and the pos_embed is regenerated at the decoder width;
the decoder's output dimension is 768 = 16 x 16 x 3, i.e. the RGB pixel values of one patch;
the decoder only returns predictions for the masked patches,
namely the line above: x = self.head(self.norm(x[:, -return_token_num:])). Visible tokens and mask tokens are concatenated with the mask tokens last, so the last return_token_num tokens are exactly the mask predictions (a toy sketch of this ordering follows this list).
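A toy sketch of that ordering (illustrative sizes only, positional embeddings omitted): visible tokens come first in x_full and mask tokens are appended at the end, so slicing the last N_mask tokens picks out exactly the mask predictions:

B, C_d = 2, 8
N_vis, N_mask = 3, 5                               # e.g. most of 8 patches masked
x_vis = torch.randn(B, N_vis, C_d)                 # encoder output after encoder_to_decoder
mask_token = torch.zeros(1, 1, C_d)                # shared learnable [MASK] embedding
x_full = torch.cat([x_vis, mask_token.expand(B, N_mask, C_d)], dim=1)
print(x_full.shape)                # torch.Size([2, 8, 8]): visible tokens first, mask tokens last
print(x_full[:, -N_mask:].shape)   # torch.Size([2, 5, 8]): the slice the decoder head predicts on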
The overall flow is illustrated below; after going through it, everything should be clear.
The MAE code:
class PretrainVisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 encoder_in_chans=3,
                 encoder_num_classes=0,
                 encoder_embed_dim=768,
                 encoder_depth=12,
                 encoder_num_heads=12,
                 decoder_num_classes=768,
                 decoder_embed_dim=512,
                 decoder_depth=8,
                 decoder_num_heads=8,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 num_classes=0,  # avoid the error from create_fn in timm
                 in_chans=0,     # avoid the error from create_fn in timm
                 ):
        super().__init__()
        self.encoder = PretrainVisionTransformerEncoder(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=encoder_in_chans,
            num_classes=encoder_num_classes,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            use_learnable_pos_emb=use_learnable_pos_emb)
        self.decoder = PretrainVisionTransformerDecoder(
            patch_size=patch_size,
            num_patches=self.encoder.patch_embed.num_patches,
            num_classes=decoder_num_classes,
            embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values)
        # project encoder features down to the decoder width
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)
        # a single shared, learnable [MASK] token
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # fixed sin-cos positional embeddings regenerated at the decoder width
        self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)
        trunc_normal_(self.mask_token, std=.02)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask):
        x_vis = self.encoder(x, mask)              # [B, N_vis, C_e]
        x_vis = self.encoder_to_decoder(x_vis)     # [B, N_vis, C_d]
        B, N, C = x_vis.shape
        # we don't unshuffle the visible tokens back to their original order,
        # but shuffle the pos embedding accordingly.
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
        # notice: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
        x = self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]
        return x
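Putting it all together, a minimal forward-pass sketch (it assumes the remaining pieces of the reference implementation, e.g. the Block class, are in scope; the 75% ratio matches the paper's default):

model = PretrainVisionTransformer(
    img_size=224, patch_size=16,
    encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
    decoder_num_classes=768, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=8)
imgs = torch.randn(2, 3, 224, 224)
num_patches, num_mask = 196, int(0.75 * 196)                # 147 masked patches
mask = torch.zeros(2, num_patches, dtype=torch.bool)
mask[:, :num_mask] = True                                   # toy mask; in practice it is shuffled per sample
pred = model(imgs, mask)
print(pred.shape)   # torch.Size([2, 147, 768]): normalized RGB pixels of each masked patch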
pretraining pipeline
There is one point to note in the pretraining process:
the pixels of each patch are normalized (per-patch mean and standard deviation), which makes the reconstruction target easier to learn:
images_squeeze = rearrange(unnorm_images, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=patch_size, p2=patch_size)
images_norm = (images_squeeze - images_squeeze.mean(dim=-2, keepdim=True)
) / (images_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
images_patch = rearrange(images_norm, 'b n p c -> b n (p c)')
labels = images_patch[bool_masked_pos].reshape(B, -1, C)
Here bool_masked_pos is the boolean index of the masked image patches; it is generated as follows:
random mask generator
class RandomMaskingGenerator:
    def __init__(self, input_size, mask_ratio):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size
        self.num_patches = self.height * self.width
        self.num_mask = int(mask_ratio * self.num_patches)

    def __repr__(self):
        repr_str = "Mask: total patches {}, mask patches {}".format(
            self.num_patches, self.num_mask
        )
        return repr_str

    def __call__(self):
        mask = np.hstack([
            np.zeros(self.num_patches - self.num_mask),
            np.ones(self.num_mask),
        ])
        np.random.shuffle(mask)
        return mask  # [196]
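Usage sketch: for a 224x224 image with 16x16 patches the grid is 14x14, so at a 75% mask ratio 147 of the 196 patches are masked on every call:

gen = RandomMaskingGenerator(input_size=14, mask_ratio=0.75)
mask = gen()                          # numpy array of 0/1, reshuffled on each call
print(mask.shape, int(mask.sum()))    # (196,) 147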
The training code is as follows:
def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, patch_size: int = 16,
                    lr_scheduler=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None):
    model.train()
    loss_func = nn.MSELoss()

    for step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # per-iteration (instead of per-epoch) learning-rate / weight-decay scheduling
        it = start_steps + step
        if lr_schedule_values is not None or wd_schedule_values is not None:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        images, bool_masked_pos = batch
        images = images.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)

        # import pdb; pdb.set_trace()
        with torch.no_grad():
            # calculate the predict label: per-patch-normalized pixels of the masked patches
            mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None]
            std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None]
            unnorm_images = images * std + mean  # in [0, 1]
            images_squeeze = rearrange(unnorm_images, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=patch_size, p2=patch_size)
            images_norm = (images_squeeze - images_squeeze.mean(dim=-2, keepdim=True)
                           ) / (images_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
            images_patch = rearrange(images_norm, 'b n p c -> b n (p c)')
            B, _, C = images_patch.shape
            labels = images_patch[bool_masked_pos].reshape(B, -1, C)

        with torch.cuda.amp.autocast():
            outputs = model(images, bool_masked_pos)
            loss = loss_func(input=outputs, target=labels)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                parameters=model.parameters(), create_graph=is_second_order)

        if lr_scheduler is not None:
            lr_scheduler.step_update(start_steps + step)
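Stripped of the repo-specific scaffolding (metric_logger, loss_scaler, AMP), one pretraining step boils down to the following simplified sketch; it assumes model, images, bool_masked_pos, the mean/std tensors and patch_size are prepared as in the loop above:

from einops import rearrange

def pretrain_step(model, optimizer, images, bool_masked_pos, mean, std, patch_size=16):
    # build the per-patch-normalized pixel targets for the masked patches only
    with torch.no_grad():
        unnorm = images * std + mean
        patches = rearrange(unnorm, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c', p1=patch_size, p2=patch_size)
        patches = (patches - patches.mean(dim=-2, keepdim=True)) \
                  / (patches.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6)
        patches = rearrange(patches, 'b n p c -> b n (p c)')
        B, _, C = patches.shape
        labels = patches[bool_masked_pos].reshape(B, -1, C)
    outputs = model(images, bool_masked_pos)   # [B, N_mask, patch_size * patch_size * 3]
    loss = nn.MSELoss()(outputs, labels)       # MSE on the masked patches only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()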
finetuning pipeline
There are a few things to note in the finetuning process:
Layer-wise learning-rate decay:
class LayerDecayValueAssigner(object):
    def __init__(self, values):
        self.values = values

    def get_scale(self, layer_id):
        return self.values[layer_id]

assigner = LayerDecayValueAssigner(list(0.75 ** (12 + 1 - i) for i in range(12 + 2)))
With a 12-layer encoder (plus the patch embedding at the bottom and the head at the top, 14 groups in total), the learning-rate scale of each group is:
[0.023757264018058777, 0.03167635202407837, 0.04223513603210449, 0.056313514709472656, 0.07508468627929688, 0.1001129150390625, 0.13348388671875, 0.177978515625, 0.2373046875, 0.31640625, 0.421875, 0.5625, 0.75, 1.0]
This strategy is used because the lower layers mostly capture generic low-level patterns while the higher layers carry the semantics; layer-wise decayed learning rates behave like gradually unfreezing the network from bottom to top. A sketch of how these scales are applied follows below.
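A minimal sketch of how those scales end up in the optimizer: the reference implementation maps each parameter name to a layer id (via a get_num_layer_for_vit-style helper in BEiT-derived code); layer_id_of below is a simplified, hypothetical stand-in for that mapping:

def build_param_groups(model, assigner, base_lr=1e-3, weight_decay=0.05):
    # one optimizer param group per layer id, each with lr = base_lr * layer scale
    groups = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        layer_id = layer_id_of(name)          # hypothetical: 0 = patch_embed, 1..12 = blocks, 13 = head
        scale = assigner.get_scale(layer_id)
        if layer_id not in groups:
            groups[layer_id] = {"params": [], "lr": base_lr * scale,
                                "lr_scale": scale, "weight_decay": weight_decay}
        groups[layer_id]["params"].append(param)
    return list(groups.values())

# optimizer = torch.optim.AdamW(build_param_groups(model, assigner), lr=1e-3)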
Use broader data augmentation and regularization (a configuration sketch follows this list):
augmentation RandAug (9, 0.5)
label smoothing 0.1
mixup 0.8
cutmix 1.0
drop path 0.1 (B/L) 0.2 (H)
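These settings map fairly directly onto timm utilities. A hedged sketch (the argument names follow timm's create_transform and Mixup; 'rand-m9-mstd0.5-inc1' is the usual string encoding of RandAug(9, 0.5), and num_classes=1000 is just a placeholder):

from timm.data import create_transform, Mixup

transform = create_transform(
    input_size=224, is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',     # RandAug (9, 0.5)
    interpolation='bicubic')
mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0,       # mixup 0.8 / cutmix 1.0
    label_smoothing=0.1, num_classes=1000)   # label smoothing 0.1
# drop path (0.1 for B/L, 0.2 for H) is passed to the ViT itself via drop_path_rate when building the model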
03
Results
A sample reconstruction result:
Seamlessly plugging this backbone into the mm-series frameworks is to be continued.