import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from torchsummary import summary
import torchvision.models as models
#from ..utils import initialize_weights
#from ..utils.misc import Conv2dDeformable
#from .config import res101_path
def initialize_weights(*nets):
    """Re-initialise the parameters of every sub-module of the given networks in place.

    ``nn.Conv2d`` and ``nn.Linear`` weights receive Kaiming-normal initialisation
    and their biases (when present) are zeroed. ``nn.BatchNorm2d`` layers are
    reset to the identity affine transform (weight 1, bias 0). Other module
    types are left untouched.

    Args:
        *nets: one or more ``nn.Module`` instances. Each is walked recursively
            via ``modules()``, so containers such as ``nn.Sequential`` work.
    """
    for net in nets:
        for module in net.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has no learnable weight/bias (both
                # are None); skip them instead of crashing on None.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
class _PyramidPoolingModule(nn.Module):
    """Pyramid pooling module (PPM) as used by PSPNet.

    For every bin size ``s`` in ``setting``, the input is adaptively
    average-pooled to ``s``, then projected to ``reduction_dim`` channels by a
    1x1 conv + BatchNorm + ReLU. The result is bilinearly resized back to the
    input's spatial size. The input and all branch outputs are concatenated
    along dim 1, so the output has ``in_dim + len(setting) * reduction_dim``
    channels (exposed as ``self.out_dim``) and the input's height/width.

    Args:
        in_dim: number of channels of the incoming feature map.
        reduction_dim: number of channels each pooling branch produces.
        setting: iterable of output sizes passed to ``nn.AdaptiveAvgPool2d``.
        bn_momentum: ``momentum`` of the branch BatchNorm layers. Defaults to
            0.95 to keep the original behaviour.
    """

    def __init__(self, in_dim, reduction_dim, setting, bn_momentum=.95):
        super(_PyramidPoolingModule, self).__init__()
        # NOTE(review): PyTorch updates running stats as
        # (1 - momentum) * running + momentum * batch_stat, so 0.95 makes the
        # running statistics follow almost only the latest batch (PyTorch's
        # default is 0.1). This looks like a value ported from a framework
        # with the opposite convention; confirm before training new models.
        self.features = nn.ModuleList(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(s),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim, momentum=bn_momentum),
                nn.ReLU(inplace=True),
            )
            for s in setting
        )
        self.out_dim = in_dim + len(self.features) * reduction_dim

    def forward(self, x):
        """Return ``x`` concatenated with every upsampled pyramid branch."""
        spatial = x.size()[2:]
        out = [x]
        for branch in self.features:
            out.append(F.interpolate(branch(x), spatial, mode='bilinear', align_corners=True))
        return torch.cat(out, 1)
class PSPNet(nn.Module):
    """PSPNet segmentation head on top of a torchvision ResNet backbone.

    ``layer3`` and ``layer4`` of the backbone are switched in place to stride-1
    dilated 3x3 convolutions (dilation 2 and 4 respectively). Their
    ``downsample.0`` convs are also set to stride 1, so these stages presumably
    stop reducing resolution (depends on torchvision's Bottleneck layout).
    ``layer4`` features go through a pyramid pooling module and a conv
    classifier. Logits are bilinearly upsampled to the input size and passed
    through a sigmoid.

    Args:
        layers: ResNet depth, one of 50, 101 or 152.
        num_classes: number of output channels.
        pretrained: forwarded to the torchvision ResNet constructor.
        use_aux: if True, an extra 1x1 conv on the ``layer3`` output (1024
            channels) produces auxiliary predictions, and ``forward`` returns
            ``(main, aux)`` instead of ``main``.
    """

    def __init__(self, layers=50, num_classes=1, pretrained=True, use_aux=True):
        super(PSPNet, self).__init__()
        self.use_aux = use_aux
        assert layers in [50, 101, 152]

        # Any depth not listed here falls back to ResNet-152, as before.
        constructors = {50: models.resnet50, 101: models.resnet101}
        backbone = constructors.get(layers, models.resnet152)(pretrained=pretrained)

        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self._dilate(self.layer3, 2)
        self._dilate(self.layer4, 4)

        # 2048 backbone channels + 4 branches * 512 = 4096 channels after the PPM.
        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        # NOTE(review): momentum=.95 weights the newest batch at 95% in
        # PyTorch's running-stat update; confirm this is intended.
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1),
        )

        # Construction/initialisation order is kept fixed so random init is
        # reproducible for a given seed.
        if use_aux:
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)
        initialize_weights(self.ppm, self.final)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _dilate(stage, rate):
        """Turn a ResNet stage into stride-1 convolutions with the given dilation (in place)."""
        for name, layer in stage.named_modules():
            if 'conv2' in name:
                layer.dilation = (rate, rate)
                layer.padding = (rate, rate)
                layer.stride = (1, 1)
            elif 'downsample.0' in name:
                layer.stride = (1, 1)

    def forward(self, x):
        """Return sigmoid probabilities at input resolution (plus aux output if enabled)."""
        target_hw = x.size()[2:]

        feats = self.layer2(self.layer1(self.layer0(x)))
        feats = self.layer3(feats)
        aux = self.aux_logits(feats) if self.use_aux else None

        logits = self.final(self.ppm(self.layer4(feats)))
        main = self.sigmoid(
            F.interpolate(logits, size=target_hw, mode='bilinear', align_corners=True)
        )
        if not self.use_aux:
            return main

        aux = self.sigmoid(
            F.interpolate(aux, size=target_hw, mode='bilinear', align_corners=True)
        )
        return main, aux
if __name__ == '__main__':
    # Smoke test: build an untrained PSPNet-50 (no aux head) and print a
    # layer-by-layer summary for a 3x256x256 input.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = PSPNet(layers=50, num_classes=1, pretrained=False, use_aux=False).to(device)
    # Tell torchsummary the device the model actually lives on ("cuda" or
    # "cpu") instead of a hard-coded string, so the two cannot drift apart.
    summary(model, (3, 256, 256), batch_size=6, device=device.type)
    model.eval()
# pspnet
# (Scraped page footer: "latest recommended article published 2023-03-25 17:30:02".)