盲超分中的核估计实现有关函数

以MANet代码为例

class MANet_s1(nn.Module):
    ''' stage1, train MANet'''

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=10, gc=32, scale=4, pca_path='./pca_matrix_aniso21_15_x2.pth',
                 code_length=15, kernel_size=21, manet_nf=256, manet_nb=1, split=2):
        super(MANet_s1, self).__init__()
        self.scale = scale
        self.kernel_size = kernel_size

        self.kernel_estimation = MANet(in_nc=in_nc, kernel_size=kernel_size, nc=[manet_nf, manet_nf * 2],
                                       nb=manet_nb, split=split)

    def forward(self, x, gt_K):
        # kernel estimation
        kernel = self.kernel_estimation(x)
        kernel = F.interpolate(kernel, scale_factor=self.scale, mode='nearest').flatten(2).permute(0, 2, 1)
        kernel = kernel.view(-1, kernel.size(1), self.kernel_size, self.kernel_size)

        # no meaning
        with torch.no_grad():
            out = F.interpolate(x, scale_factor=self.scale, mode='nearest')

        return out, kernel
  • flatten()是对多维数据的降维函数。
    flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。
    python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维
    比如一个数据的维度是( S 0 , S 1 , S 2......... , S n ) , flatten(m)后的数据为 (S0,S1,S2,...,Sm-2,Sm-1,Sm*Sm+1*Sm+2*...*Sn)
  • 定义一个张量a为(2,3,4),a这个数据从0维展开,就是(2 ∗ 3 ∗ 4 ),维度就是(24),a从1维展开flatten(1),就是( 2 , 3 ∗ 4 ) ,也就是(2,12),a若是从2维展开flatten(2),那就是(2,3,4)和之前没有变化。

Permute(0,2,1)表示维度变换,原来的顺序是0,1,2  现在变换成0,2,1

例如:

>>> x = torch.randn(2, 3, 5) 
>>> x.size() 
torch.Size([2, 3, 5]) 
>>> x.permute(2, 0, 1).size() 
torch.Size([5, 2, 3])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值