pytorch 卷积权重形状

# -*- coding: utf-8 -*-
import argparse
import os
import copy

import torch
from torch import nn
import numpy as np
import math


class CNN(nn.Module):
    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(CNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=3, padding=1),
            # nn.Conv2d(num_channels, d, kernel_size=(1,3), padding=5//2),
            # nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=5//2),
            nn.PReLU(d)
        )
        self.mid_part1 = nn.Sequential(nn.Conv2d(d, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part2 = nn.Sequential(nn.Conv2d(d+s, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part3 = nn.Sequential(nn.Conv2d(d+s+s, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part4 = nn.Sequential(nn.Conv2d(d+s+s+s, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part = nn.Sequential(nn.Conv2d(d+s+s+s+s, scale_factor ** 2, kernel_size=3, padding=1), nn.PReLU(scale_factor ** 2))
        
        # self.last_part = nn.ConvTranspose2d(d+s+s+s+s, num_channels, kernel_size=3, stride=scale_factor, padding=3//2,
        #                                     output_padding=scale_factor-1)
        self.last_part = nn.PixelShuffle(scale_factor)
        # 具体一点来说,Pixelshuffle会将shape为(∗,r2C,H,W)(∗,r2C,H,W)的Tensor给reshape成(∗,C,rH,rW)(∗,C,rH,rW)的Tensor。形式化地说,它的输入输出的shape如下: 
#- Input: (N,C∗upscale_factor2,H,W)(N,C∗upscale_factor2,H,W) 
#- Output:  (N,C,H∗upscale_factor,W∗upscale_factor)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.first_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part1:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part2:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part3:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        for m in self.mid_part4:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
        # nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        # nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        print(x.size())
        out1 = self.first_part(x)
        print(out1.size())
        temp = torch.cat([out1, x], 1)
        print(temp.size())
        out2 = self.mid_part1(out1)
        print(out2.size())
        cat2 = torch.cat([out1, out2], 1)
        print(cat2.size())
        out3 = self.mid_part2(cat2)
        print(out3.size())
        cat3 = torch.cat([out1, out2, out3], 1)
        print(cat3.size())
        out4 = self.mid_part3(cat3)
        print(out4.size())
        cat4 = torch.cat([out1, out2, out3, out4], 1)
        print(cat4.size())
        out5 = self.mid_part4(cat4)
        print(out5.size())
        print(self.mid_part4)
        for m in self.mid_part4:
            if isinstance(m, nn.Conv2d):
                print('weight形状:',m.weight.data.size())  #卷积的权重大小
                print(m.bias.data)
        cat5 = torch.cat([out1, out2, out3, out4, out5], 1)
        print(cat5.size())
        out6 = self.mid_part(cat5)
        print('out6.size():',out6.size())
        m = self.last_part
        print(m)
        out = self.last_part(out6)
        print(out.size())
        # print(m.weight.data.size())
        # print(m.bias.data)
        return out

if __name__ == '__main__':

    model = CNN(scale_factor = 3)
    print(model)
    input = torch.randn(12,1,28,36)
    with torch.no_grad():
        pre = model(input)
    # print(pre)
    # print(pre.size())
    pred = pre.clamp(0.0, 1.0)
    print('pred.size():',pred.size())
    print(pred[..., 0].shape)
    print(pred[..., 1].shape)
    print(pred[..., 2].shape)
    # pred = pred.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    # print(pred.shape)
    # print(pred[..., 0].shape)
    params = sum(p.numel() for p in model.parameters()) #计算模型总参数量
    print(params)


结果:
CNN(
  (first_part): Sequential(
    (0): Conv2d(1, 56, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=56)
  )
  (mid_part1): Sequential(
    (0): Conv2d(56, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part2): Sequential(
    (0): Conv2d(68, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part3): Sequential(
    (0): Conv2d(80, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part4): Sequential(
    (0): Conv2d(92, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=12)
  )
  (mid_part): Sequential(
    (0): Conv2d(104, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=9)
  )
  (last_part): PixelShuffle(upscale_factor=3)
)
torch.Size([12, 1, 28, 36])
torch.Size([12, 56, 28, 36])
torch.Size([12, 57, 28, 36])
torch.Size([12, 12, 28, 36])
torch.Size([12, 68, 28, 36])
torch.Size([12, 12, 28, 36])
torch.Size([12, 80, 28, 36])
torch.Size([12, 12, 28, 36])
torch.Size([12, 92, 28, 36])
torch.Size([12, 12, 28, 36])
Sequential(
  (0): Conv2d(92, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): PReLU(num_parameters=12)
)
weight形状: torch.Size([12, 92, 3, 3])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
torch.Size([12, 104, 28, 36])
out6.size(): torch.Size([12, 9, 28, 36])
PixelShuffle(upscale_factor=3)
torch.Size([12, 1, 84, 108])
pred.size(): torch.Size([12, 1, 84, 108])
torch.Size([12, 1, 84])
torch.Size([12, 1, 84])
torch.Size([12, 1, 84])
41122  #模型总参数量


 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值