08 Shufflenet

 仅是个人笔记,代码学习资源来源B站博主霹雳吧啦Wz的个人空间_哔哩哔哩_bilibili

 Shufflenet对标mobilenet,轻量化网络

一、 理论知识

1. Shufflenet V1

mobilenet使用DW卷积,resnext使用group卷积等,会造成信息分割问题,如下左图

Shufflenet V1Block如下图:

左图步距为1的block,  右图步距为2的block

2. Shufflenet V2

提出四个原则:

  • 输入输出通道相同时,内存访问量MAC最小
  • 分组数过大会导致MAC增加
  • 碎片化操作对并行加速不友好
  • 逐元素操作会增加内存的消耗

更改Block如下:

                        左图是V1两个Block;                                右图是V2两个Block

 参数配置:

经过每个stage特征图长宽减半,通道数翻倍

二、 代码复现

1. 网络搭建

shuffle功能的实现:

def channel_shuffle(x: Tensor, groups: int) -> Tensor:   
    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    x = x.view(batch_size, groups, channels_per_group, height, width)  
    # [batch_size, groups, channels_per_group, height, width] 

    x = torch.transpose(x, 1, 2).contiguous()   
    # 经过.transpose后变得不连续,先经过.contiguous()才能进行.view

    x = x.view(batch_size, -1, height, width)    # flatten
    return x

 2. 训练部分

 训练中进度条以及损失准确率等信息的显示:

######### 伪模板

model.train()
loss_function = torch.nn.CrossEntropyLoss()
optimizer.zero_grad()
data_loader = tqdm(data_loader, file=sys.stdout)

accu_loss = torch.zeros(1).to(device)  # 累计损失
accu_num = torch.zeros(1).to(device)   # 累计预测正确的样本数
sample_num = 0                         # 累计样本数

for step, data in enumerate(data_loader):
    #  读取数据
    images, labels = data
    #  预测
    pred = model(images.to(device))
    pred_classes = torch.max(pred, dim=1)[1]

    #  计数累加总样本、正确的样本
    sample_num += images.shape[0]
    accu_num += torch.eq(pred_classes, labels.to(device)).sum()

    #  损失计算
    loss = loss_function(pred, labels.to(device))
    loss.backward()
    #  损失累加
    accu_loss += loss.detach()

    #  打印信息
    data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                           accu_loss.item() / (step + 1),
                                                                              accu_num.item() / sample_num)

                                                                       # 总损失/轮数---平均损失 

return accu_loss.item() / (step + 1), accu_num.item() / sample_num  
         # 平均损失                        # 正确率

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值