import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
def weights_init_normal(m):
    """Re-initialise one layer in place, chosen by its class name.

    * Class name contains ``"Conv"``: weights are drawn from a normal
      distribution with mean 0.0 and std 0.02.
    * Class name contains ``"BatchNorm2d"``: weights are drawn from a normal
      distribution with mean 1.0 and std 0.02, and the bias is set to 0.
    * Anything else is left untouched.

    Args:
        m: the layer to initialise; it must expose ``weight`` (and ``bias``
            for the batch-norm case).
    """
    layer_name = m.__class__.__name__

    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)
##############################
# U-NET
##############################
import torch
import torch.nn as nn
from torchvision import models
def double_conv(in_channels, out_channels):
    """Build a stack of two 3x3 convolutions, each followed by an in-place ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both use ``padding=1``.

    Args:
        in_channels: channel count of the incoming tensor.
        out_channels: channel count produced by both convolutions.

    Returns:
        An ``nn.Sequential`` of four layers: conv, ReLU, conv, ReLU.
    """
    stack = []
    for width_in in (in_channels, out_channels):
        stack.append(nn.Conv2d(width_in, out_channels, 3, padding=1))
        stack.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stack)
class BasicBlock(nn.Module):
    """Two-convolution residual block, the block type of ResNet-18 and ResNet-34.

    The main branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its result is
    added to a shortcut branch and passed through a final ReLU.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: base channel count; the block outputs
            ``out_channels * expansion`` channels.
        stride: stride of the first 3x3 convolution (and of the projection
            shortcut when one is needed).
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        final_channels = out_channels * BasicBlock.expansion

        # Main branch; only the first convolution carries the stride.
        main_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, final_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(final_channels),
        ]
        self.residual_function = nn.Sequential(*main_layers)

        # Identity shortcut unless the resolution or channel count changes, in
        # which case a strided 1x1 convolution + BN matches the dimensions.
        shape_changes = stride != 1 or in_channels != final_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, final_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(final_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual_function(x) + shortcut(x))``."""
        summed = self.residual_function(x) + self.shortcut(x)
        return F.relu(summed, inplace=True)
class BottleNeck(nn.Module):
    """Three-convolution bottleneck residual block for ResNets with 50+ layers.

    The main branch is conv1x1 -> BN -> ReLU -> conv3x3 -> BN -> ReLU ->
    conv1x1 -> BN. Its result is added to a shortcut branch and passed through
    a final ReLU.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: width of the inner convolutions; the block outputs
            ``out_channels * expansion`` channels.
        stride: stride of the 3x3 convolution (and of the projection shortcut
            when one is needed).
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        final_channels = out_channels * BottleNeck.expansion

        squeeze = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The 3x3 convolution is the only strided layer of the main branch.
        spatial = [
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        widen = [
            nn.Conv2d(out_channels, final_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(final_channels),
        ]
        self.residual_function = nn.Sequential(*squeeze, *spatial, *widen)

        # Identity shortcut unless the resolution or channel count changes.
        shape_changes = stride != 1 or in_channels != final_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, final_channels, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(final_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual_function(x) + shortcut(x))``."""
        summed = self.residual_function(x) + self.shortcut(x)
        return F.relu(summed, inplace=True)
class ResNet(nn.Module):
    """ResNet encoder combined with a U-Net style decoder.

    The encoder is a strided 7x7 stem, a max-pool and four residual stages.
    The decoder repeatedly upsamples by 2, concatenates the encoder feature
    map of matching resolution along the channel axis and applies
    ``double_conv``; a final head upsamples once more and projects to
    ``out_channel`` channels with a 1x1 convolution.

    Args:
        in_channel: channel count of the input tensor.
        out_channel: channel count of the output tensor.
        block: residual block class. It is called as
            ``block(in_channels, out_channels, stride)`` and must expose an
            integer ``expansion`` class attribute.
        num_block: sequence of four ints, the number of blocks in each of the
            four residual stages.
    """

    def __init__(self, in_channel, out_channel, block, num_block):
        super().__init__()
        # Running input width for the next residual block; _make_layer
        # reads and updates it.
        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # conv2_x keeps stride 1; each later stage halves the resolution.
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder stages output ``width * block.expansion`` channels, so the
        # decoder input widths must scale with the expansion as well. With
        # expansion == 1 these are the original 768 / 384 / 192.
        widen = block.expansion
        self.dconv_up3 = double_conv(256 * widen + 512 * widen, 256)
        self.dconv_up2 = double_conv(128 * widen + 256, 128)
        self.dconv_up1 = double_conv(64 * widen + 128, 64)

        # Input is the 64-channel decoder map concatenated with the
        # 64-channel stem output.
        self.dconv_last = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, out_channel, 1)
        )

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one residual stage (a sequence of residual blocks).

        Args:
            block: block type, basic block or bottleneck block.
            out_channels: base width passed to every block of this stage.
            num_blocks: how many blocks the stage contains.
            stride: stride of the first block; all later blocks use stride 1.

        Returns:
            An ``nn.Sequential`` holding the blocks. ``self.in_channels`` is
            updated to ``out_channels * block.expansion`` after each block.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x``, then decode it with skip connections.

        The encoder halves the resolution five times and the decoder doubles
        it five times, so inputs whose height and width are multiples of 32
        come back at their original resolution. NOTE(review): other sizes
        likely fail in ``torch.cat`` because the upsampled maps no longer
        match the skip tensors; confirm against callers.

        Returns:
            Tensor with ``out_channel`` channels.
        """
        conv1 = self.conv1(x)
        temp = self.maxpool(conv1)
        conv2 = self.conv2_x(temp)
        conv3 = self.conv3_x(conv2)
        conv4 = self.conv4_x(conv3)
        bottle = self.conv5_x(conv4)

        x = self.upsample(bottle)
        x = torch.cat([x, conv4], dim=1)
        x = self.dconv_up3(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up2(x)

        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up1(x)

        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        out = self.dconv_last(x)
        return out

    @staticmethod
    def _pair_state_keys(source_keys, target_keys):
        """Pair pretrained state-dict keys with this model's keys by suffix.

        Both lists are walked in order. Each source key is paired with the
        next unused target key that contains the source key's last dotted
        component (``weight``, ``bias``, ``running_mean``, ...). Matching
        stops at the first source key containing ``"fc"`` or as soon as
        either list is exhausted, so it never indexes past the end.

        Returns:
            List of ``(source_key, target_key)`` tuples.
        """
        pairs = []
        cursor = 0
        total = len(target_keys)
        for source_key in source_keys:
            if "fc" in source_key:
                break
            suffix = source_key.split(".")[-1]
            while cursor < total and suffix not in target_keys[cursor]:
                cursor += 1
            if cursor >= total:
                break
            pairs.append((source_key, target_keys[cursor]))
            cursor += 1
        return pairs

    def load_pretrained_weights(self):
        """Copy torchvision's pretrained ResNet-34 encoder weights into this model.

        Keys are matched positionally by suffix (see ``_pair_state_keys``),
        which relies on this model registering its encoder parameters in the
        same order as torchvision's ResNet-34. NOTE(review): that presumably
        only holds for ``BasicBlock`` with ``[3, 4, 6, 3]``; confirm before
        calling on other configurations.

        Tensors whose shape differs from the target (for example the stem
        convolution when ``in_channel != 3``) keep their current values and
        are reported instead of aborting the whole load.

        Raises:
            Exception: whatever ``load_state_dict`` raises, re-raised after
                printing an error message.
        """
        model_dict = self.state_dict()
        resnet34_weights = models.resnet34(pretrained=True).state_dict()

        skipped = []
        pairs = self._pair_state_keys(list(resnet34_weights), list(model_dict))
        for k_res, k_my in pairs:
            if resnet34_weights[k_res].shape != model_dict[k_my].shape:
                skipped.append(k_my)
                continue
            model_dict[k_my] = resnet34_weights[k_res]

        try:
            self.load_state_dict(model_dict)
        except Exception:
            print("Error resnet34 weights in mynet !")
            raise
        if skipped:
            print("Skipped shape-mismatched resnet34 weights: " + ", ".join(skipped))
        print("Loaded resnet34 weights in mynet !")
def resnet18(in_channel=3, out_channel=3):
    """Return a ResNet-18 based model (``BasicBlock``, stages ``[2, 2, 2, 2]``).

    ``ResNet`` requires the input and output channel counts as its first two
    arguments, so they are accepted here and forwarded.

    Args:
        in_channel: channel count of the input tensor.
        out_channel: channel count of the output tensor.
    """
    return ResNet(in_channel, out_channel, BasicBlock, [2, 2, 2, 2])
def resnet34(in_channel, out_channel, pretrain=True):
    """Return a ResNet-34 based model (``BasicBlock``, stages ``[3, 4, 6, 3]``).

    Args:
        in_channel: channel count of the input tensor.
        out_channel: channel count of the output tensor.
        pretrain: when true, call ``load_pretrained_weights()`` on the new
            model before returning it.
    """
    net = ResNet(in_channel, out_channel, BasicBlock, [3, 4, 6, 3])
    if not pretrain:
        return net
    net.load_pretrained_weights()
    return net
def resnet50(in_channel=3, out_channel=3):
    """Return a ResNet-50 based model (``BottleNeck``, stages ``[3, 4, 6, 3]``).

    ``ResNet`` requires the input and output channel counts as its first two
    arguments, so they are accepted here and forwarded.

    NOTE(review): ``BottleNeck`` widens every encoder stage by its
    ``expansion`` of 4, so the forward pass only works when the ``ResNet``
    decoder sizes its inputs from ``block.expansion``; confirm before use.

    Args:
        in_channel: channel count of the input tensor.
        out_channel: channel count of the output tensor.
    """
    return ResNet(in_channel, out_channel, BottleNeck, [3, 4, 6, 3])
def resnet101(in_channel=3, out_channel=3):
    """Return a ResNet-101 based model (``BottleNeck``, stages ``[3, 4, 23, 3]``).

    ``ResNet`` requires the input and output channel counts as its first two
    arguments, so they are accepted here and forwarded.

    NOTE(review): ``BottleNeck`` widens every encoder stage by its
    ``expansion`` of 4, so the forward pass only works when the ``ResNet``
    decoder sizes its inputs from ``block.expansion``; confirm before use.

    Args:
        in_channel: channel count of the input tensor.
        out_channel: channel count of the output tensor.
    """
    return ResNet(in_channel, out_channel, BottleNeck, [3, 4, 23, 3])
def resnet152(in_channel=3, out_channel=3):
    """Return a ResNet-152 based model (``BottleNeck``, stages ``[3, 8, 36, 3]``).

    ``ResNet`` requires the input and output channel counts as its first two
    arguments, so they are accepted here and forwarded.

    NOTE(review): ``BottleNeck`` widens every encoder stage by its
    ``expansion`` of 4, so the forward pass only works when the ``ResNet``
    decoder sizes its inputs from ``block.expansion``; confirm before use.

    Args:
        in_channel: channel count of the input tensor.
        out_channel: channel count of the output tensor.
    """
    return ResNet(in_channel, out_channel, BottleNeck, [3, 8, 36, 3])
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
    """Convolutional discriminator that scores a pair of images.

    The two images are concatenated along the channel axis and passed through
    four stride-2 convolution blocks (64, 128, 256, 512 filters), a one-pixel
    zero pad on the left and top, and a final bias-free convolution with a
    single output channel.

    Args:
        in_channels: channel count of EACH input image; the first convolution
            therefore receives ``in_channels * 2`` channels.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        widths = [in_channels * 2, 64, 128, 256, 512]
        stages = []
        for index, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Each block halves the resolution with a 4x4, stride-2 convolution.
            stages.append(nn.Conv2d(width_in, width_out, 4, stride=2, padding=1))
            # The very first block is the only one without normalisation.
            if index > 0:
                stages.append(nn.InstanceNorm2d(width_out))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        stages.append(nn.ZeroPad2d((1, 0, 1, 0)))
        stages.append(nn.Conv2d(512, 1, 4, padding=1, bias=False))
        self.model = nn.Sequential(*stages)

    def forward(self, img_A, img_B):
        """Concatenate both images along the channel axis and score the result."""
        stacked = torch.cat([img_A, img_B], dim=1)
        return self.model(stacked)
if __name__ == '__main__':
    # Smoke test: build the pretrained ResNet-34 model with 3 input and
    # 3 output channels, print its structure, then print the output shape
    # for one random 1x3x512x512 batch.
    model = resnet34(3, 3, True)
    print(model)
    sample = torch.rand((1, 3, 512, 512))
    prediction = model.forward(sample)
    print(prediction.shape)
# a = torch.randn(1,3,512,512)
# d = torch.randn(1,3,512,512)
#
# dis = Discriminator()
#
# c = dis(d)
#
# print((c.shape))
# NOTE: The results are poor and it is unclear where the mistake is; the output images after 200 training epochs are shown below.