Pytorch实现深度可分类卷积

import torch
import torch.nn as nn
from torchsummary import summary
import time

device = torch.device('cuda')

class Common_Convolution(nn.Module):
    def __init__(self, in_chs, out_chs):
        super(Common_Convolution, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_chs,
            out_channels=out_chs,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            groups=1,
        )
    def forward(self, x):
        out = self.conv(x)
        return out

class Depth_Separable_Convolution(nn.Module):
    def __init__(self, in_chs, out_chs):
        super(Depth_Separable_Convolution, self).__init__()
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_chs,
                out_channels=in_chs,
                kernel_size=(3, 3),
                stride=(1, 1),
                groups=in_chs,
                padding=(1, 1),
            ),
            nn.BatchNorm2d(num_features=in_chs),
            nn.ReLU(),
        )
        self.pointwise_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=(1, 1),
                stride=(1, 1),
                groups=1,
                padding=(0, 0),
            ),
            nn.BatchNorm2d(num_features=out_chs),
            nn.ReLU(),
        )
    def forward(self, x):
        x = self.depthwise_conv(x)
        out = self.pointwise_conv(x)
        return out

s1 = time.perf_counter()
common_conv = Common_Convolution(in_chs=3, out_chs=30).to(device)
print(summary(common_conv, input_size=(3, 64, 64)))
e1 = time.perf_counter()

s2 = time.perf_counter()
depthwise_conv = Depth_Separable_Convolution(in_chs=3, out_chs=3).to(device)
print(summary(depthwise_conv, input_size=(3, 64, 64)))
e2 = time.perf_counter()
s3 = time.perf_counter()
pointwise_conv = Depth_Separable_Convolution(in_chs=3, out_chs=30).to(device)
print(summary(pointwise_conv, input_size=(3, 64, 64)))
e3 = time.perf_counter()
print('Common convolution: ', e1 - s1)
print('Depth separable convolution: ', (e2 - s2) + (e3 - s3))



  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值