These are just personal notes; the code and study material come from the Bilibili creator 霹雳吧啦Wz.
ShuffleNet is a lightweight network, positioned as a counterpart to MobileNet.
一、 Theory
1. ShuffleNet V1
MobileNet uses depthwise (DW) convolutions and ResNeXt uses group convolutions; stacking such grouped operations keeps information isolated within each group, as shown in the left figure below.
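The isolation can be read off directly from the weight shape of a grouped convolution in PyTorch (a minimal check, not from the original notes):

```python
import torch.nn as nn

# A grouped 1x1 convolution: 4 input channels, 4 output channels, 2 groups.
# Each output channel's kernel only covers 4 // 2 = 2 input channels,
# so information never crosses group borders.
conv = nn.Conv2d(4, 4, kernel_size=1, groups=2, bias=False)
print(conv.weight.shape)  # torch.Size([4, 2, 1, 1])
```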
The ShuffleNet V1 block is shown below:
Left: the stride-1 block; right: the stride-2 block.
2. ShuffleNet V2
The paper proposes four practical guidelines:
- Equal input and output channel widths minimize memory access cost (MAC).
- Excessive group convolution increases MAC.
- Network fragmentation (many small scattered ops) is unfriendly to parallel acceleration.
- Element-wise operations add non-negligible memory overhead.
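Guideline 1 can be checked numerically. For a 1×1 convolution with c1 input channels, c2 output channels and an h×w feature map, MAC = hw(c1 + c2) + c1·c2; holding the FLOPs hw·c1·c2 fixed, MAC is smallest when c1 = c2. The channel pairs below are illustrative:

```python
# For a 1x1 convolution: FLOPs = h*w*c1*c2, MAC = h*w*(c1 + c2) + c1*c2.
def mac(h, w, c1, c2):
    return h * w * (c1 + c2) + c1 * c2

h = w = 56
pairs = [(128, 128), (64, 256), (32, 512)]  # all have c1 * c2 = 16384, i.e. equal FLOPs
for c1, c2 in pairs:
    print(c1, c2, mac(h, w, c1, c2))
# The balanced split (128, 128) gives the smallest MAC.
```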
The blocks are revised accordingly:
Left: the two V1 blocks; right: the two V2 blocks.
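A minimal sketch of the stride-1 V2 block: channel split, then one branch of 1×1 conv, 3×3 DW conv, 1×1 conv, then concat and channel shuffle. The class name and the width 116 are assumptions for illustration, not the paper's reference implementation:

```python
import torch
import torch.nn as nn


def channel_shuffle(x, groups):
    b, c, h, w = x.size()
    x = x.view(b, groups, c // groups, h, w).transpose(1, 2).contiguous()
    return x.view(b, -1, h, w)


class ShuffleV2BlockS1(nn.Module):
    """Sketch of the stride-1 ShuffleNet V2 block:
    channel split -> right branch (1x1 -> 3x3 DW -> 1x1) -> concat -> shuffle."""

    def __init__(self, channels):
        super().__init__()
        half = channels // 2
        self.branch = nn.Sequential(
            nn.Conv2d(half, half, 1, bias=False),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
            nn.Conv2d(half, half, 3, padding=1, groups=half, bias=False),  # DW conv
            nn.BatchNorm2d(half),
            nn.Conv2d(half, half, 1, bias=False),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)            # channel split into equal halves
        out = torch.cat((x1, self.branch(x2)), dim=1)
        return channel_shuffle(out, 2)        # mix information between the halves


x = torch.randn(1, 116, 28, 28)
y = ShuffleV2BlockS1(116)(x)
print(y.shape)  # torch.Size([1, 116, 28, 28]) — stride 1 keeps the shape
```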
Parameter configuration:
After each stage, the feature map's height and width are halved and the channel count doubles.
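The halving/doubling pattern is easy to sanity-check with a single stride-2 convolution (a shape illustration only; the real stride-2 V2 block uses two branches):

```python
import torch
import torch.nn as nn

# After each stage: spatial size halves, channel count doubles
# (116 -> 232 matches the 1.0x ShuffleNet V2 width setting).
down = nn.Conv2d(116, 232, kernel_size=3, stride=2, padding=1, bias=False)
y = down(torch.randn(1, 116, 28, 28))
print(y.shape)  # torch.Size([1, 232, 14, 14])
```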
二、 Code Reproduction
1. Building the network
Implementation of the channel shuffle operation:
import torch
from torch import Tensor


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape to [batch_size, groups, channels_per_group, height, width]
    x = x.view(batch_size, groups, channels_per_group, height, width)
    # .transpose makes the tensor non-contiguous, so call .contiguous() before .view
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten back to [batch_size, num_channels, height, width]
    x = x.view(batch_size, -1, height, width)
    return x
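A tiny worked example makes the interleaving visible (the function is repeated here so the snippet runs standalone):

```python
import torch


def channel_shuffle(x, groups):
    b, c, h, w = x.size()
    x = x.view(b, groups, c // groups, h, w).transpose(1, 2).contiguous()
    return x.view(b, -1, h, w)


# 6 channels, 2 groups: groups [0,1,2] and [3,4,5] get interleaved.
x = torch.arange(6).view(1, 6, 1, 1)   # channel order [0, 1, 2, 3, 4, 5]
y = channel_shuffle(x, 2)
print(y.flatten().tolist())  # [0, 3, 1, 4, 2, 5]
```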
2. Training
Showing a progress bar plus the running loss and accuracy during training:
######### pseudo-template
import sys
import torch
from tqdm import tqdm


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer.zero_grad()

    data_loader = tqdm(data_loader, file=sys.stdout)
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    sample_num = 0                         # accumulated number of samples
    for step, data in enumerate(data_loader):
        # read a batch
        images, labels = data
        # forward pass
        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        # accumulate the total and correctly classified sample counts
        sample_num += images.shape[0]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()
        # compute the loss and backpropagate
        loss = loss_function(pred, labels.to(device))
        loss.backward()
        # update the weights, then reset the gradients for the next step
        optimizer.step()
        optimizer.zero_grad()
        # accumulate the loss
        accu_loss += loss.detach()
        # show running stats in the progress bar
        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num)

    # total loss / number of steps = mean loss; correct / total = accuracy
    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
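To see the template run end-to-end, here is a toy pass over random data, with the tqdm wrapper omitted; the model, optimizer, and tensor sizes are all made up for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy setup: a linear classifier over random 3x8x8 "images" with 5 classes.
torch.manual_seed(0)
device = torch.device("cpu")
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(
    TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 5, (32,))),
    batch_size=8,
)

loss_function = nn.CrossEntropyLoss()
model.train()
accu_loss = torch.zeros(1).to(device)
accu_num = torch.zeros(1).to(device)
sample_num = 0
for step, (images, labels) in enumerate(loader):
    pred = model(images.to(device))
    pred_classes = torch.max(pred, dim=1)[1]
    sample_num += images.shape[0]
    accu_num += torch.eq(pred_classes, labels.to(device)).sum()
    loss = loss_function(pred, labels.to(device))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    accu_loss += loss.detach()

print(accu_loss.item() / (step + 1), accu_num.item() / sample_num)
```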