简介:
由ConvNets学习的表示已被证明对局部纹理有强烈的偏差,从而严重损害全局信息,急需改进。因此,已经做出努力来升级宏观层次的架构和上下文聚合模块。大多数现代深度神经网络倾向于编码极低或高复杂度的博弈理论相互作用,而不是最具鉴别性的中间交互作用,这限制了它们对复杂样本的表示能力和鲁棒性。与之前的尝试不同,MogaNet通过多阶博弈论交互的视角研究了现代卷积网络的表示能力,为解释基于博弈论的深度架构中编码的特征交互行为和效果提供了一个新的视角。
MogaNet:
图1为MogaNet宏观架构图,分为四个阶段。对于阶段i,输入的图像或特征首先被输入到嵌入stem中以调节特征分辨率并嵌入到Ci维度。假设输入图像的分辨率为H×W,则四个阶段的特征分别为H/4×W/4,H/8×W/8,H/16×W/16和H/32×W/32的分辨率。然后,嵌入的特征流入Ni个Moga块中,由空间和通道聚合块组成,用于进一步的上下文提取和聚合。在最终输出后,添加GAP和线性层用于分类任务。
多阶门控聚合:
MogaNet提出一种空间聚合块SMixer(·),用于在统一设计中聚合多阶上下文,如下图所示,由两个级联组件组成。
其中,Moga(·)是一个多阶门控聚合模块,包括门控Fϕ(·)和上下文分支Gψ(·)。 具体细节:
class MogaBlock(nn.Module):
    """One MogaNet block: a spatial-aggregation (Moga attention) branch
    followed by a channel-aggregation FFN branch, each applied as a
    residual whose contribution is modulated by a learnable per-channel
    layer-scale parameter.
    """

    def __init__(self,
                 embed_dims,
                 ffn_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 act_type='GELU',
                 norm_type='BN',
                 init_value=1e-5,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_type='SiLU',
                 attn_force_fp32=False,
                 ):
        super(MogaBlock, self).__init__()
        self.out_channels = embed_dims
        # Normalization ahead of the spatial-aggregation branch.
        self.norm1 = build_norm_layer(norm_type, embed_dims)
        # Spatial attention: multi-order gated aggregation.
        self.attn = MultiOrderGatedAggregation(
            embed_dims,
            attn_dw_dilation=attn_dw_dilation,
            attn_channel_split=attn_channel_split,
            attn_act_type=attn_act_type,
            attn_force_fp32=attn_force_fp32,
        )
        # Stochastic depth on both residual branches (identity when rate=0).
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        # Normalization ahead of the channel-MLP branch.
        self.norm2 = build_norm_layer(norm_type, embed_dims)
        # Channel MLP: depthwise conv + channel-aggregation FFN.
        hidden_channels = int(embed_dims * ffn_ratio)
        self.mlp = ChannelAggregationFFN(
            embed_dims=embed_dims,
            feedforward_channels=hidden_channels,
            act_type=act_type,
            ffn_drop=drop_rate,
        )
        # Layer-scale parameters, initialized to a small value so each
        # residual branch starts close to the identity mapping.
        self.layer_scale_1 = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)

    def forward(self, x):
        """Apply the spatial and channel sub-blocks with scaled residuals."""
        # Spatial aggregation sub-block.
        residual = x
        x = residual + self.drop_path(
            self.layer_scale_1 * self.attn(self.norm1(x)))
        # Channel aggregation sub-block.
        residual = x
        x = residual + self.drop_path(
            self.layer_scale_2 * self.mlp(self.norm2(x)))
        return x
通道聚合块:是一个通道聚合的前馈神经网络,用于特征提取和降维。
class ChannelAggregationFFN(nn.Module):
    """Feed-forward network with depthwise mixing and channel aggregation.

    Pipeline: 1x1 expansion -> depthwise conv -> activation -> dropout ->
    channel decomposition -> 1x1 reduction -> dropout.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 kernel_size=3,
                 act_type='GELU',
                 ffn_drop=0.):
        super(ChannelAggregationFFN, self).__init__()
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        # 1x1 expansion: embed_dims -> feedforward_channels.
        self.fc1 = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=self.feedforward_channels,
            kernel_size=1)
        # Depthwise conv mixes spatial context per channel; padding keeps
        # the spatial resolution unchanged.
        self.dwconv = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=self.feedforward_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=self.feedforward_channels)
        self.act = build_act_layer(act_type)
        # 1x1 reduction: feedforward_channels -> embed_dims.
        self.fc2 = nn.Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)
        # Channel aggregation: squeeze all channels into one summary map.
        self.decompose = nn.Conv2d(
            in_channels=self.feedforward_channels,  # C -> 1
            out_channels=1, kernel_size=1,
        )
        # Learnable per-channel scale for the decomposed residual.
        self.sigma = ElementScale(
            self.feedforward_channels, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_act_layer(act_type)

    def feat_decompose(self, x):
        # Summarize C channels into a single map ([B, C, H, W] -> [B, 1, H, W])
        # and re-inject the scaled deviation of x from that summary.
        summary = self.decompose_act(self.decompose(x))
        return x + self.sigma(x - summary)

    def forward(self, x):
        # Expansion projection with depthwise spatial mixing.
        out = self.drop(self.act(self.dwconv(self.fc1(x))))
        # Channel aggregation, then reduction projection.
        out = self.feat_decompose(out)
        out = self.drop(self.fc2(out))
        return out
多阶DWConv层 :目的是通过不同的深度可分离卷积操作,以不同的通道分割比例来提取多个阶段的特征。
class MultiOrderDWConv(nn.Module):
    """Multi-order depthwise convolution layer.

    Channels are split into three groups that pass through depthwise
    convolutions of growing receptive field (5x5, 5x5 dilated, 7x7 dilated),
    then the groups are concatenated and fused by a pointwise convolution.
    """

    def __init__(self,
                 embed_dims,
                 dw_dilation=[1, 2, 3,],
                 channel_split=[1, 3, 4,],
                 ):
        super(MultiOrderDWConv, self).__init__()
        total = sum(channel_split)
        self.split_ratio = [i / total for i in channel_split]
        # Channel counts per branch; branch 0 absorbs any rounding remainder.
        self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
        self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
        self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
        self.embed_dims = embed_dims
        assert len(dw_dilation) == len(channel_split) == 3
        assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
        assert embed_dims % total == 0
        # Branch 0: basic 5x5 depthwise conv over ALL channels.
        # Padding (1 + 4*d)//2 == d*(5-1)/2 keeps the spatial size.
        self.DW_conv0 = nn.Conv2d(
            in_channels=self.embed_dims,
            out_channels=self.embed_dims,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[0]) // 2,
            groups=self.embed_dims,
            stride=1, dilation=dw_dilation[0],
        )
        # Branch 1: dilated 5x5 depthwise conv on the middle channel group.
        self.DW_conv1 = nn.Conv2d(
            in_channels=self.embed_dims_1,
            out_channels=self.embed_dims_1,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[1]) // 2,
            groups=self.embed_dims_1,
            stride=1, dilation=dw_dilation[1],
        )
        # Branch 2: dilated 7x7 depthwise conv on the last channel group.
        self.DW_conv2 = nn.Conv2d(
            in_channels=self.embed_dims_2,
            out_channels=self.embed_dims_2,
            kernel_size=7,
            padding=(1 + 6 * dw_dilation[2]) // 2,
            groups=self.embed_dims_2,
            stride=1, dilation=dw_dilation[2],
        )
        # Pointwise conv fusing the three branches across channels.
        self.PW_conv = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=1)

    def forward(self, x):
        # Branch 0 runs first; branches 1 and 2 refine slices of its output.
        base = self.DW_conv0(x)
        lo = self.embed_dims_0
        hi = self.embed_dims - self.embed_dims_2
        mid_feats = self.DW_conv1(base[:, lo:lo + self.embed_dims_1, ...])
        high_feats = self.DW_conv2(base[:, hi:, ...])
        fused = torch.cat([base[:, :lo, ...], mid_feats, high_feats], dim=1)
        return self.PW_conv(fused)
多阶门控聚合操作:用于特征聚合任务
class MultiOrderGatedAggregation(nn.Module):
    """Spatial aggregation block Moga(.).

    Combines a gating branch (1x1 conv) with a multi-order context branch
    (MultiOrderDWConv) via elementwise gating, wrapped in a residual
    connection. `attn_force_fp32` forces the gating math into float32.
    """

    def __init__(self,
                 embed_dims,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_type='SiLU',
                 attn_force_fp32=False,
                 ):
        super(MultiOrderGatedAggregation, self).__init__()
        self.embed_dims = embed_dims
        self.attn_force_fp32 = attn_force_fp32
        self.proj_1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        # Gating branch F_phi(.)
        self.gate = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        # Context (value) branch G_psi(.)
        self.value = MultiOrderDWConv(
            embed_dims=embed_dims,
            dw_dilation=attn_dw_dilation,
            channel_split=attn_channel_split,
        )
        self.proj_2 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        # Activations for the value and gating paths.
        self.act_value = build_act_layer(attn_act_type)
        self.act_gate = build_act_layer(attn_act_type)
        # Learnable per-channel scale for the feature-decompose residual.
        self.sigma = ElementScale(
            embed_dims, init_value=1e-5, requires_grad=True)

    def feat_decompose(self, x):
        """Project, then re-weight the deviation from the global context."""
        x = self.proj_1(x)
        # Global average context: [B, C, H, W] -> [B, C, 1, 1]
        x_d = F.adaptive_avg_pool2d(x, output_size=1)
        x = x + self.sigma(x - x_d)
        x = self.act_value(x)
        return x

    def forward_gating(self, g, v):
        """Gated fusion computed in float32 (autocast disabled)."""
        with torch.autocast(device_type='cuda', enabled=False):
            g = g.to(torch.float32)
            v = v.to(torch.float32)
            return self.proj_2(self.act_gate(g) * self.act_gate(v))

    def forward(self, x):
        shortcut = x.clone()
        # proj 1x1 + feature decomposition
        x = self.feat_decompose(x)
        # gating and value branches
        g = self.gate(x)
        v = self.value(x)
        # aggregation
        if not self.attn_force_fp32:
            x = self.proj_2(self.act_gate(g) * self.act_gate(v))
        else:
            # BUGFIX: pass the raw gate/value tensors. forward_gating applies
            # act_gate internally, so the original call
            # forward_gating(self.act_gate(g), self.act_gate(v)) activated
            # twice, making the fp32 path compute a different function than
            # the default path above.
            x = self.forward_gating(g, v)
        x = x + shortcut
        return x
测试:
我使用一个病害识别的数据集来测试MogaNet:
数据集:Plantvillage数据集,"Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy"
论文细节:Efficient Multi-order Gated Aggregation Network
训练代码:
import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from models.moganet import moganet_xtiny
from torch.autograd import Variable
from torchvision import datasets
torch.backends.cudnn.benchmark = False
import warnings
warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
# 定义训练过程
def train(model, device, train_loader, optimizer, epoch, model_ema):
    """Run one training epoch and return (average loss, average top-1 acc).

    Relies on module-level globals defined in __main__: mixup_fn, use_amp,
    criterion_train, scaler, CLIP_GRAD.

    Args:
        model: the network being trained.
        device: torch device for inputs/targets.
        train_loader: DataLoader yielding (images, labels).
        optimizer: optimizer stepping the model parameters.
        epoch: current epoch index (for logging only).
        model_ema: optional timm ModelEma tracker, or None.
    """
    model.train()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # Mixup/CutMix yields mixed inputs and soft targets.
        samples, targets = mixup_fn(data, target)
        optimizer.zero_grad()
        if use_amp:
            # BUGFIX: run the forward pass under autocast and feed the mixed
            # `samples` -- the original computed `output = model(data)` before
            # autocast, so the mixup augmentation never reached the model and
            # AMP never applied to the forward pass.
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = torch.nan_to_num(criterion_train(output, targets))
            scaler.scale(loss).backward()
            # BUGFIX: unscale before clipping, otherwise CLIP_GRAD is compared
            # against scaled gradients and the clip threshold is meaningless.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(samples)
            loss = criterion_train(output, targets)
            loss.backward()
            optimizer.step()
        if model_ema is not None:
            model_ema.update(model)
        torch.cuda.synchronize()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        # Accuracy is measured against the original hard labels; with mixup
        # active this is only an approximate training-time metric.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        # BUGFIX: loss_meter was updated twice per batch in the original.
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss = loss_meter.avg
    acc = acc1_meter.avg
    print('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch, ave_loss, acc))
    return ave_loss, acc
# 验证过程
@torch.no_grad()
def val(model, device, test_loader):
    """Evaluate the model on the validation set and save checkpoints.

    Saves the full model as best.pth when the top-1 accuracy improves, and
    a state-dict checkpoint every epoch. Relies on module-level globals:
    Best_ACC, criterion_val, file_dir, epoch, use_ema.

    Returns:
        (true labels, predictions, average loss, average top-1 accuracy).
    """
    global Best_ACC
    model.eval()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    val_list = []
    pred_list = []
    for data, target in test_loader:
        val_list.extend(t.item() for t in target)
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion_val(output, target)
        _, pred = torch.max(output.data, 1)
        pred_list.extend(p.item() for p in pred)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
    acc = acc1_meter.avg
    print('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\n'.format(
        loss_meter.avg, acc))
    # Unwrap DataParallel once so both save paths below are identical.
    # BUGFIX: the original duplicated the whole save logic in two branches
    # and used pickle_protocol=0 in only one of them; both now save the
    # same way with the default protocol.
    raw_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    if acc > Best_ACC:
        torch.save(raw_model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    state = {
        'epoch': epoch,
        'state_dict': raw_model.state_dict(),
        'Best_ACC': Best_ACC,
    }
    if use_ema:
        # NOTE(review): this stores the same weights under a second key, as
        # the original did; presumably the EMA weights are intended here --
        # confirm against the caller, which passes model_ema.ema as `model`.
        state['state_dict_ema'] = raw_model.state_dict()
    torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
    return val_list, pred_list, loss_meter.avg, acc
def seed_everything(seed=0):
    """Seed Python hashing, the `random` module, and PyTorch RNGs.

    Args:
        seed (int): seed applied to every RNG source.
    """
    import random  # local import: keeps the function self-contained
    # BUGFIX: the env var was misspelled 'PYHTONHASHSEED' in the original,
    # so hash randomization was never actually pinned.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # no-op on CPU-only machines
    # Force deterministic cuDNN kernels (may reduce throughput).
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    # NOTE(review): 'maganet' looks like a typo for 'moganet'; kept as-is
    # so existing checkpoint/resume paths keep working.
    file_dir = 'checkpoints/maganet/'
    # BUGFIX: the original if/else called os.makedirs in BOTH branches;
    # one call with exist_ok=True is equivalent and race-free.
    os.makedirs(file_dir, exist_ok=True)
    model_lr = 1e-3          # initial learning rate
    BATCH_SIZE = 8
    EPOCHS = 100
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_amp = True           # mixed-precision training
    use_dp = True            # DataParallel across visible GPUs
    classes = 3
    resume = None            # checkpoint path to resume from, or None
    CLIP_GRAD = 5.0          # gradient-norm clipping threshold
    Best_ACC = 0             # best validation accuracy seen so far
    use_ema = False
    model_ema_decay = 0.9998
    start_epoch = 1
    seed = 0
    seed_everything(seed)
    # ------------------------------------------------------------------
    # Data pipelines
    # ------------------------------------------------------------------
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Mixup/CutMix produces soft targets for SoftTargetCrossEntropy.
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=0.1, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=classes)
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    # Persist the class-name -> index mapping for later inference.
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
    criterion_train = SoftTargetCrossEntropy()   # matches mixup soft labels
    criterion_val = torch.nn.CrossEntropyLoss()  # hard labels at validation
    # ------------------------------------------------------------------
    # Model
    # ------------------------------------------------------------------
    model_ft = moganet_xtiny()
    # Replace the classification head for this dataset's class count.
    num_fr = model_ft.head.in_features
    model_ft.head = nn.Linear(num_fr, classes)
    print(model_ft)
    if resume:
        model = torch.load(resume)
        print(model['state_dict'].keys())
        model_ft.load_state_dict(model['state_dict'], strict=False)
        Best_ACC = model['Best_ACC']
        start_epoch = model['epoch'] + 1
    model_ft.to(DEVICE)
    optimizer = optim.AdamW(model_ft.parameters(), lr=model_lr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=30, eta_min=1e-9)
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()
    if torch.cuda.device_count() > 1 and use_dp:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model_ft = torch.nn.DataParallel(model_ft)
    if use_ema:
        model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device=DEVICE,
            resume=resume)
    else:
        model_ema = None
    # ------------------------------------------------------------------
    # Train / validate loop
    # ------------------------------------------------------------------
    is_set_lr = False
    log_dir = {}
    train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []
    epoch_info = []
    # When resuming, reload the metric history so the curves stay continuous.
    if resume and os.path.isfile(file_dir + "result.json"):
        with open(file_dir + 'result.json', 'r', encoding='utf-8') as file:
            logs = json.load(file)
            train_acc_list = logs['train_acc']
            train_loss_list = logs['train_loss']
            val_acc_list = logs['val_acc']
            val_loss_list = logs['val_loss']
            epoch_list = logs['epoch_list']
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_list.append(epoch)
        log_dir['epoch_list'] = epoch_list
        train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch, model_ema)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        log_dir['train_acc'] = train_acc_list
        log_dir['train_loss'] = train_loss_list
        # Validate the EMA weights when EMA is enabled, otherwise the model.
        if use_ema:
            val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)
        else:
            val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        log_dir['val_acc'] = val_acc_list
        log_dir['val_loss'] = val_loss_list
        log_dir['best_acc'] = Best_ACC
        with open(file_dir + '/result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(log_dir))
        # BUGFIX: target_names must be a sequence of class-name strings;
        # the original passed the class_to_idx dict itself.
        print(classification_report(val_list, pred_list,
                                    target_names=list(dataset_train.class_to_idx)))
        epoch_info.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        # Rewrite the whole per-epoch summary file each epoch.
        with open('epoch_info.txt', 'w') as f:
            for epoch_data in epoch_info:
                f.write(f"Epoch: {epoch_data['epoch']}\n")
                f.write(f"Train Loss: {epoch_data['train_loss']}\n")
                f.write(f"Train Acc: {epoch_data['train_acc']}\n")
                f.write(f"Val Loss: {epoch_data['val_loss']}\n")
                f.write(f"Val Acc: {epoch_data['val_acc']}\n")
                f.write("\n")
        # NOTE(review): with EPOCHS = 100 the else branch is dead code; the
        # 600 threshold only matters if EPOCHS is raised above it.
        if epoch < 600:
            cosine_schedule.step()
        else:
            if not is_set_lr:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = 1e-6
                is_set_lr = True
        # Loss curves, refreshed every epoch.
        fig = plt.figure(1)
        plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')
        plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')
        plt.legend(["Train Loss", "Val Loss"], loc="upper right")
        plt.xlabel(u'epoch')
        plt.ylabel(u'loss')
        plt.title('Model Loss ')
        plt.savefig(file_dir + "/loss.png")
        plt.close(1)
        # Accuracy curves, refreshed every epoch.
        fig2 = plt.figure(2)
        plt.plot(epoch_list, train_acc_list, 'g-', label=u'Train Acc')
        plt.plot(epoch_list, val_acc_list, 'y-', label=u'Val Acc')
        plt.legend(["Train Acc", "Val Acc"], loc="lower right")
        plt.title("Model Acc")
        plt.ylabel("acc")
        plt.xlabel("epoch")
        plt.savefig(file_dir + "/acc.png")
        plt.close(2)
结果:在100轮训练内,该模型在训练集上达到的最优acc为97.34%,在验证集上最优acc达到99.93%。
总结:
MogaNet是一种从多阶博弈论交互的新观点出发的计算高效的纯卷积网络架构。通过特别关注多阶博弈交互,设计了一个统一的多阶聚合块,有效捕捉了跨空间和通道维度的鲁棒多阶上下文。
以上为全部内容!