# -*- coding: utf-8 -*-
import argparse
import os
import copy
import torch
from torch import nn
import numpy as np
import math
class CNN(nn.Module):
    """Dense-connection super-resolution CNN with sub-pixel upsampling.

    Every mid-level conv consumes the channel-wise concatenation of all
    preceding feature maps (DenseNet-style skip connections). The final
    conv emits ``scale_factor ** 2`` channels, which ``nn.PixelShuffle``
    rearranges from (N, r^2, H, W) into the upscaled image (N, 1, r*H, r*W).

    Args:
        scale_factor: spatial upscaling ratio ``r``.
        num_channels: channels of the input image (default 1, grayscale).
        d: feature channels produced by the first conv.
        s: feature channels produced by each mid conv.
        m: unused; kept only for backward compatibility of the signature.
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        super(CNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=3, padding=1),
            nn.PReLU(d),
        )
        # Each mid conv's in_channels grows by s because it receives the
        # concatenation of the first-part output and all previous mid outputs.
        self.mid_part1 = nn.Sequential(nn.Conv2d(d, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part2 = nn.Sequential(nn.Conv2d(d + s, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part3 = nn.Sequential(nn.Conv2d(d + 2 * s, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part4 = nn.Sequential(nn.Conv2d(d + 3 * s, s, kernel_size=3, padding=1), nn.PReLU(s))
        self.mid_part = nn.Sequential(
            nn.Conv2d(d + 4 * s, scale_factor ** 2, kernel_size=3, padding=1),
            nn.PReLU(scale_factor ** 2),
        )
        # PixelShuffle reshapes (N, C*r^2, H, W) -> (N, C, r*H, r*W).
        self.last_part = nn.PixelShuffle(scale_factor)
        self._initialize_weights()

    def _initialize_weights(self):
        """He-normal init for every Conv2d weight; zero biases.

        std = sqrt(2 / (out_channels * kernel_numel)). The original code
        repeated this loop five times and skipped ``self.mid_part``; all
        conv stages are now initialized consistently.
        """
        stages = (self.first_part, self.mid_part1, self.mid_part2,
                  self.mid_part3, self.mid_part4, self.mid_part)
        for stage in stages:
            for layer in stage:
                if isinstance(layer, nn.Conv2d):
                    fan = layer.out_channels * layer.weight.data[0][0].numel()
                    nn.init.normal_(layer.weight.data, mean=0.0, std=math.sqrt(2 / fan))
                    nn.init.zeros_(layer.bias.data)

    def forward(self, x):
        """Upscale ``x`` (N, num_channels, H, W) to (N, num_channels, r*H, r*W)."""
        out1 = self.first_part(x)
        out2 = self.mid_part1(out1)
        out3 = self.mid_part2(torch.cat([out1, out2], 1))
        out4 = self.mid_part3(torch.cat([out1, out2, out3], 1))
        out5 = self.mid_part4(torch.cat([out1, out2, out3, out4], 1))
        out6 = self.mid_part(torch.cat([out1, out2, out3, out4, out5], 1))
        return self.last_part(out6)
if __name__ == '__main__':
    # Smoke test: push a random batch through the network, clamp to the
    # valid image range, and report shapes plus total parameter count.
    model = CNN(scale_factor=3)
    print(model)
    sample = torch.randn(12, 1, 28, 36)  # (N, C, H, W); renamed to avoid shadowing builtin input()
    with torch.no_grad():
        pre = model(sample)
    pred = pre.clamp(0.0, 1.0)
    print('pred.size():', pred.size())
    print(pred[..., 0].shape)
    print(pred[..., 1].shape)
    print(pred[..., 2].shape)
    params = sum(p.numel() for p in model.parameters())  # total parameter count
    print(params)
# Result (sample run output, commented out so the module stays importable):
# CNN(
#   (first_part): Sequential(
#     (0): Conv2d(1, 56, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): PReLU(num_parameters=56)
#   )
#   (mid_part1): Sequential(
#     (0): Conv2d(56, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): PReLU(num_parameters=12)
#   )
#   (mid_part2): Sequential(
#     (0): Conv2d(68, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): PReLU(num_parameters=12)
#   )
#   (mid_part3): Sequential(
#     (0): Conv2d(80, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): PReLU(num_parameters=12)
#   )
#   (mid_part4): Sequential(
#     (0): Conv2d(92, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): PReLU(num_parameters=12)
#   )
#   (mid_part): Sequential(
#     (0): Conv2d(104, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): PReLU(num_parameters=9)
#   )
#   (last_part): PixelShuffle(upscale_factor=3)
# )
# torch.Size([12, 1, 28, 36])
# torch.Size([12, 56, 28, 36])
# torch.Size([12, 57, 28, 36])
# torch.Size([12, 12, 28, 36])
# torch.Size([12, 68, 28, 36])
# torch.Size([12, 12, 28, 36])
# torch.Size([12, 80, 28, 36])
# torch.Size([12, 12, 28, 36])
# torch.Size([12, 92, 28, 36])
# torch.Size([12, 12, 28, 36])
# Sequential(
#   (0): Conv2d(92, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   (1): PReLU(num_parameters=12)
# )
# weight shape: torch.Size([12, 92, 3, 3])
# tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
# torch.Size([12, 104, 28, 36])
# out6.size(): torch.Size([12, 9, 28, 36])
# PixelShuffle(upscale_factor=3)
# torch.Size([12, 1, 84, 108])
# pred.size(): torch.Size([12, 1, 84, 108])
# torch.Size([12, 1, 84])
# torch.Size([12, 1, 84])
# torch.Size([12, 1, 84])
# 41122  # total parameter count