以MANet代码为例
class MANet_s1(nn.Module):
''' stage1, train MANet'''
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=10, gc=32, scale=4, pca_path='./pca_matrix_aniso21_15_x2.pth',
code_length=15, kernel_size=21, manet_nf=256, manet_nb=1, split=2):
super(MANet_s1, self).__init__()
self.scale = scale
self.kernel_size = kernel_size
self.kernel_estimation = MANet(in_nc=in_nc, kernel_size=kernel_size, nc=[manet_nf, manet_nf * 2],
nb=manet_nb, split=split)
def forward(self, x, gt_K):
# kernel estimation
kernel = self.kernel_estimation(x)
kernel = F.interpolate(kernel, scale_factor=self.scale, mode='nearest').flatten(2).permute(0, 2, 1)
kernel = kernel.view(-1, kernel.size(1), self.kernel_size, self.kernel_size)
# no meaning
with torch.no_grad():
out = F.interpolate(x, scale_factor=self.scale, mode='nearest')
return out, kernel
- flatten()是对多维数据的降维函数。
flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。
python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。
比如一个数据的维度是( S 0 , S 1 , S 2......... , S n ) , flatten(m)后的数据为 (S0,S1,S2,...,Sm-2,Sm-1,Sm*Sm+1*Sm+2*...*Sn) - 定义一个张量a为(2,3,4),a这个数据从0维展开,就是(2 ∗ 3 ∗ 4 ),维度就是(24),a从1维展开flatten(1),就是( 2 , 3 ∗ 4 ) ,也就是(2,12),a若是从2维展开flatten(2),那就是(2,3,4)和之前没有变化。
Permute(0,2,1)表示维度变换,原来的顺序是0,1,2 现在变换成0,2,1
例如:
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(2, 0, 1).size()
torch.Size([5, 2, 3])