# MicroNet (ICCV 2021)

import torch
import torch.nn as nn
import torch.nn.functional as F

def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


########################################################################
# sigmoid and tanh
########################################################################
# h_sigmoid: inputs in [-3, 3] map linearly onto [0, h_max]; values outside saturate.
class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) * (h_max / 6)``.

    :param inplace: forwarded to ``nn.ReLU6``. The ReLU6 is applied to the
        temporary ``x + 3``, not to ``x`` itself.
    :param h_max: upper end of the output range.
    """

    def __init__(self, inplace=True, h_max=1):
        super().__init__()
        # Fold the division by 6 into the stored scale once, at construction.
        self.h_max = h_max / 6
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) * self.h_max

# h_tanh: inputs in [-3, 3] map linearly onto [-h_max, h_max]; values outside saturate.
class h_tanh(nn.Module):
    """Hard tanh: ``ReLU6(x + 3) * h_max / 3 - h_max``.

    :param inplace: forwarded to ``nn.ReLU6``. The ReLU6 is applied to the
        temporary ``x + 3``, not to ``x`` itself.
    :param h_max: magnitude of the output range.
    """

    def __init__(self, inplace=True, h_max=1):
        super().__init__()
        self.h_max = h_max
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        # ReLU6 output lies in [0, 6]; scale it to [0, 2 * h_max], then centre on zero.
        scaled = self.relu(x + 3) * self.h_max
        return scaled / 3 - self.h_max


########################################################################
# wrap functions
########################################################################

def get_act_layer(inp, oup, mode='SE1', act_relu=True, act_max=2, act_bias=True, init_a=[1.0, 0.0], reduction=4, init_b=[0.0, 0.0], g=None, act='relu', expansion=True):
    """Build the activation module selected by ``mode``.

    Supported modes:

    * ``'SE1'``       - ``SELayer`` followed by ``ReLU6`` (or by an empty
      ``nn.Sequential`` when ``act_relu`` is false).
    * ``'SE0'``       - ``SELayer`` alone, wrapped in ``nn.Sequential``.
    * ``'NA'``        - ``ReLU6``.
    * ``'LeakyReLU'`` - ``nn.LeakyReLU``.
    * ``'RReLU'``     - ``nn.RReLU``.
    * ``'PReLU'``     - ``nn.PReLU``.
    * ``'DYShiftMax'`` - ``DYShiftMax`` configured from the remaining arguments.

    For the four plain activations, ``act_relu=False`` yields an empty
    ``nn.Sequential`` instead. Any other ``mode`` returns ``None``.

    ``act_bias`` and ``act`` are accepted but not used by this function.
    ``init_a`` / ``init_b`` default to module-level lists that are only
    passed through, never modified here.

    :return: an ``nn.Module``, or ``None`` for an unrecognised mode.
    """

    def _unless_disabled(activation_cls, **kwargs):
        # act_relu=False replaces the activation with an empty nn.Sequential.
        return activation_cls(**kwargs) if act_relu else nn.Sequential()

    if mode == 'SE1':
        squeeze_excite = SELayer(inp, oup, reduction=reduction)
        return nn.Sequential(squeeze_excite, _unless_disabled(nn.ReLU6, inplace=True))
    if mode == 'SE0':
        return nn.Sequential(SELayer(inp, oup, reduction=reduction))
    if mode == 'NA':
        return _unless_disabled(nn.ReLU6, inplace=True)
    if mode == 'LeakyReLU':
        return _unless_disabled(nn.LeakyReLU, inplace=True)
    if mode == 'RReLU':
        return _unless_disabled(nn.RReLU, inplace=True)
    if mode == 'PReLU':
        return _unless_disabled(nn.PReLU)
    if mode == 'DYShiftMax':
        return DYShiftMax(
            inp,
            oup,
            act_max=act_max,
            act_relu=act_relu,
            init_a=init_a,
            reduction=reduction,
            init_b=init_b,
            g=g,
            expansion=expansion,
        )
    return None

########################################################################
# dynamic activation layers (SE, DYShiftMax, etc)
########################################################################

class SELayer(nn.Module):
    """Channel gate: global average pool -> two linear layers -> h_sigmoid.

    The gate has ``oup`` channels and multiplies the output tensor
    channel-wise.

    :param inp: channel count of the tensor the gate is computed from.
    :param oup: channel count of the gate.
    :param reduction: passed to ``get_squeeze_channels`` to size the hidden layer.
    """

    def __init__(self, inp, oup, reduction=4):
        super().__init__()
        self.oup = oup
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Width of the bottleneck between the two linear layers.
        hidden = get_squeeze_channels(inp, reduction)
        print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, hidden))

        self.fc = nn.Sequential(
            nn.Linear(inp, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, oup),
            h_sigmoid(),
        )

    def forward(self, x):
        # A list is read as [tensor the gate is computed from, tensor to rescale];
        # anything else plays both roles.
        gate_source, target = (x[0], x[1]) if isinstance(x, list) else (x, x)

        batch, channels, _height, _width = gate_source.size()
        pooled = self.avg_pool(gate_source).view(batch, channels)
        gate = self.fc(pooled).view(batch, self.oup, 1, 1)
        return target * gate

class DYShiftMax(nn.Module):
    """Dynamic Shift-Max activation.

    For an input ``x`` and a channel-permuted copy ``x_shift`` (see
    ``self.index``), the layer computes

        ``max(a1 * x + b1 * x_shift, a2 * x + b2 * x_shift)``   if ``act_relu``
        ``a1 * x + b1 * x_shift``                               otherwise

    The per-sample, per-channel coefficients come from a small gating
    network: global average pool -> Linear -> ReLU -> Linear -> h_sigmoid.
    Its output is centred (``- 0.5``), scaled by ``2 * act_max`` and offset by
    ``init_a`` (for the ``a`` coefficients) or ``init_b`` (for ``b``).

    NOTE(review): the elementwise products only broadcast when ``oup`` matches
    the channel count of ``x`` - the layer appears to assume ``inp == oup``.

    :param inp: number of input channels; must be divisible by the group count.
    :param oup: number of coefficient channels produced per coefficient.
    :param reduction: divisor applied to ``inp`` to size the gating bottleneck.
    :param act_max: half-range of the dynamic part of each coefficient.
    :param act_relu: if true use the two-branch max, otherwise a single branch.
    :param init_a: offsets added to ``a1`` / ``a2``.
    :param init_b: offsets added to ``b1`` / ``b2``.
    :param relu_before_pool: apply an in-place ReLU before pooling. Because it
        is in-place, the input tensor itself is rectified before the rest of
        ``forward`` reads it.
    :param g: indexable whose element ``[1]`` gives the group setting;
        ``None`` means a single group.
    :param expansion: if true (and the group setting is not 1) the group
        count becomes ``inp // g[1]`` instead of ``g[1]``.
    """

    def __init__(self, inp, oup, reduction=4, act_max=1.0, act_relu=True, init_a=[0.0, 0.0], init_b=[0.0, 0.0], relu_before_pool=False, g=None, expansion=False):
        super(DYShiftMax, self).__init__()
        self.oup = oup
        self.act_max = act_max * 2
        self.act_relu = act_relu
        self.avg_pool = nn.Sequential(
            nn.ReLU(inplace=True) if relu_before_pool else nn.Sequential(),
            nn.AdaptiveAvgPool2d(1),
        )

        # Four coefficient sets (a1, b1, a2, b2) for the max form, two otherwise.
        self.exp = 4 if act_relu else 2
        # Copy so instances never share (or alias) the default argument lists.
        self.init_a = list(init_a)
        self.init_b = list(init_b)

        # Bottleneck width of the gating network: a multiple of 4, at least 4.
        squeeze = max(_make_divisible(inp // reduction, 4), 4)
        print('reduction: {}, squeeze: {}/{}'.format(reduction, inp, squeeze))
        print('init-a: {}, init-b: {}'.format(init_a, init_b))

        self.fc = nn.Sequential(
            nn.Linear(inp, squeeze),
            nn.ReLU(inplace=True),
            nn.Linear(squeeze, oup * self.exp),
            h_sigmoid(),
        )

        # The original code set ``g = 1`` for None and then indexed ``g[1]``,
        # which raised TypeError; None now simply means one group.
        self.g = 1 if g is None else g[1]
        if self.g != 1 and expansion:
            self.g = inp // self.g
        print('group shuffle: {}, divide group: {}'.format(self.g, expansion))
        self.gc = inp // self.g  # channels per group

        # Channel permutation used to build x_shift: lay the channel ids out as
        # a (groups, channels-per-group) grid, then rotate by one position along
        # both axes (the first group moves last, and the first channel of each
        # group moves last).
        channel_grid = torch.arange(inp, dtype=torch.long).view(self.g, self.gc)
        self.index = torch.roll(channel_grid, shifts=(-1, -1), dims=(0, 1)).reshape(inp)
        self.expansion = expansion

    def forward(self, x):
        b, c, _, _ = x.size()

        # Gating network -> coefficients centred on zero, in [-act_max, act_max]
        # (self.act_max already holds twice the constructor argument).
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, self.oup * self.exp, 1, 1)
        y = (y - 0.5) * self.act_max

        # Channel-shuffled copy of the input.
        x_shift = x[:, self.index, :, :]

        if self.exp == 4:
            a1, b1, a2, b2 = torch.split(y, self.oup, dim=1)
            a1 = a1 + self.init_a[0]
            a2 = a2 + self.init_a[1]
            b1 = b1 + self.init_b[0]
            b2 = b2 + self.init_b[1]

            z1 = x * a1 + x_shift * b1
            z2 = x * a2 + x_shift * b2
            return torch.max(z1, z2)

        a1, b1 = torch.split(y, self.oup, dim=1)
        a1 = a1 + self.init_a[0]
        b1 = b1 + self.init_b[0]
        return x * a1 + x_shift * b1

def get_squeeze_channels(inp, reduction):
    """Return the bottleneck width for a squeeze layer.

    With ``reduction == 4`` the result is the plain integer quotient
    ``inp // reduction``; for any other reduction the quotient is passed
    through ``_make_divisible(..., 4)``.

    :param inp: number of input channels.
    :param reduction: reduction ratio.
    :return: number of squeeze channels.
    """
    quotient = inp // reduction
    return quotient if reduction == 4 else _make_divisible(quotient, 4)

# Smoke test, executed whenever this module is run or imported: build one
# DYShiftMax layer and push a random (2, 16, 16, 16) batch through it.
a = torch.rand(2, 16, 16, 16)
b = DYShiftMax(
    16,
    16,
    reduction=4,
    act_max=1.0,
    act_relu=True,
    init_a=[2.0, 2.0],
    init_b=[2.0, 2.0],
    relu_before_pool=False,
    g=[4, 4],
    expansion=False,
)
c = b(a)
# print(b)  # uncomment to inspect the module structure

# (Non-code web-page footer text removed; it was not part of the program.)