# 2021-07-05

# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 16:55:26 2021

@author: Administrator
"""

import os
import numpy as np
from torchvision import transforms as tfs
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
import cv2
from torch.utils.data import DataLoader
import  matplotlib.pyplot as plt
import os
import  torch.nn as nn
from torchvision import models
import time
import torch.nn.functional as F
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Work around the "duplicate OpenMP runtime" crash that MKL-linked builds
# of PyTorch/NumPy can hit (most common on Windows).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# The original set `torch.backends.cudnn.enable` -- a typo that merely
# creates an unused attribute; the real flag is `enabled`.
torch.backends.cudnn.enabled = True
# Let cuDNN auto-tune convolution algorithms (fast for fixed input sizes).
torch.backends.cudnn.benchmark = True

# Train on the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')



inria_root = r'E:\INRIA_DATASET'


def read_images(root=inria_root, train=True, n_train=50000, n_val=14980, val_offset=50000):
    """Build parallel lists of image/label file paths for the INRIA tiles.

    Parameters
    ----------
    root : str
        Dataset root directory.
    train : bool
        When True return the training split, otherwise the validation split.
    n_train : int
        Number of training tiles (files 000000.tif .. n_train-1).
    n_val : int
        Number of validation tiles.
    val_offset : int
        Index of the first validation tile; validation files are numbered
        val_offset .. val_offset + n_val - 1.

    Returns
    -------
    (list[str], list[str])
        Image paths and label paths, index-aligned.
    """
    if train:
        data = [os.path.join(root, 'train', 'Image_batch', '{:0>6d}.tif'.format(i)) for i in range(n_train)]
        print('train_data finish')
        label = [os.path.join(root, 'train', 'label_batch', '{:0>6d}.tif'.format(i)) for i in range(n_train)]
        print('train_label finish')
    else:
        data = [os.path.join(root, 'val', 'Image_batch', '{:0>6d}.tif'.format(i + val_offset)) for i in range(n_val)]
        print('val_data finish')
        label = [os.path.join(root, 'val', 'label_batch', '{:0>6d}.tif'.format(i + val_offset)) for i in range(n_val)]
        print('val_label finish')
    return data, label


# Semantic classes and their RGB colours in the label images.
classes = ['background', 'building']

# RGB colour of each class index: background = black, building = white.
colormap = [[0, 0, 0], [255, 255, 255]]

# Lookup table from a packed 24-bit RGB value to its class index.
# uint8 is plenty for 2 classes and keeps the 256**3-entry table at
# ~16 MB (the original float64 table needed ~134 MB).
cm2lbl = np.zeros(256 ** 3, dtype='uint8')
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i


def image2label(im):
    """Convert an RGB label image into an (H, W) int64 class-index array.

    Each pixel's (R, G, B) triple is packed into a single 24-bit key and
    looked up in ``cm2lbl``.
    """
    data = np.array(im, dtype='int64')
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    return np.array(cm2lbl[idx], dtype='int64')


def img_transforms(im, label):
    """Normalise the input image and turn the RGB label into class indices.

    Returns a (float image tensor, int64 label tensor) pair ready for the
    DataLoader.
    """
    normalize = tfs.Compose([
        tfs.ToTensor(),
        # ImageNet channel statistics, matching the pretrained backbone.
        tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_tensor = normalize(im)
    label_tensor = torch.from_numpy(image2label(label)).long()
    return image_tensor, label_tensor


class VOCSegDataset(Dataset):
    """Segmentation dataset that lazily loads image/label tile pairs.

    Parameters
    ----------
    train : bool
        Select the training split (True) or the validation split (False).
    transforms : callable
        ``transforms(img, label) -> (tensor, tensor)`` applied to each
        pair of PIL images (e.g. ``img_transforms``).
    """

    def __init__(self, train, transforms):
        self.transforms = transforms
        data_list, label_list = read_images(train=train)
        self.data_list = data_list    # image file paths
        self.label_list = label_list  # label file paths, index-aligned
        # Fixed message: the original concatenation printed e.g.
        # "Read50000images" with no spaces.
        print('Read ' + str(len(self.data_list)) + ' images')

    def __getitem__(self, idx):
        """Load the idx-th pair and return the transformed tensors."""
        img = Image.open(self.data_list[idx])
        # Force 3-channel RGB so image2label can pack each pixel into a
        # 24-bit colour key.
        label = Image.open(self.label_list[idx]).convert('RGB')
        return self.transforms(img, label)

    def __len__(self):
        return len(self.data_list)

# Build the training/validation datasets (reads the path lists only;
# actual image files are opened lazily per __getitem__).
voc_train=VOCSegDataset(True,img_transforms)
voc_test=VOCSegDataset(False,img_transforms)


# Mini-batch loaders; only the training split is shuffled.
train_data=DataLoader(voc_train,batch_size=16,shuffle=True)
valid_data=DataLoader(voc_test,batch_size=16)

# Number of segmentation classes (background / building).
num_classes=len(classes)

class ResNet(nn.Module):
    """ResNet backbone trimmed for DeepLab.

    Keeps the stem + layer1 as a low-level feature extractor and
    layer2 + layer3 as the high-level extractor, and removes the stride
    of layer3's first block so the high-level features keep a finer
    spatial resolution.  ``forward`` returns (high_level, low_level).
    """

    # backbone name -> (constructor, high-level channels, low-level channels)
    _CONFIGS = {
        'resnet18': (resnet18, 256, 64),
        'resnet34': (resnet34, 256, 64),
        'resnet50': (resnet50, 1024, 256),
        'resnet101': (resnet101, 1024, 256),
    }

    def __init__(self, backbone='resnet50', pretrained_path=None):
        super().__init__()
        # Unknown names fall back to resnet152, mirroring the original
        # if/elif chain's final else branch.
        ctor, high_ch, low_ch = self._CONFIGS.get(backbone, (resnet152, 1024, 256))
        # Download torchvision's pretrained weights only when no local
        # checkpoint path was supplied.
        model = ctor(pretrained=not pretrained_path)
        self.final_out_channels = high_ch
        self.low_level_inplanes = low_ch
        if pretrained_path:
            model.load_state_dict(torch.load(pretrained_path))

        children = list(model.children())
        # conv1 / bn1 / relu / maxpool / layer1 -> low-level features
        self.early_extractor = nn.Sequential(*children[:5])
        # layer2 + layer3 -> high-level features
        self.later_extractor = nn.Sequential(*children[5:7])

        # Undo the downsampling of layer3's first block (stride 2 -> 1).
        first_block = self.later_extractor[-1][0]
        first_block.conv1.stride = (1, 1)
        first_block.conv2.stride = (1, 1)
        first_block.downsample[0].stride = (1, 1)

    def forward(self, x):
        low = self.early_extractor(x)
        high = self.later_extractor(low)
        return high, low

class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs four parallel dilated-conv branches plus an image-level pooling
    branch, concatenates them (5 * 256 = 1280 channels) and projects back
    down to 256 channels with dropout.
    """

    def __init__(self, inplanes=2048, output_stride=16):
        super(ASPP, self).__init__()
        if output_stride == 16:
            rates = [1, 6, 12, 18]
        elif output_stride == 8:
            rates = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=rates[0])
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=rates[1], dilation=rates[1])
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=rates[2], dilation=rates[2])
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=rates[3], dilation=rates[3])

        # Image-level features: pool to 1x1 and project to 256 channels.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Upsample the 1x1 image-level features back to the branch size.
        pooled = F.interpolate(pooled, size=branches[-1].size()[2:],
                               mode='bilinear', align_corners=True)
        branches.append(pooled)
        merged = torch.cat(branches, dim=1)
        merged = self.relu(self.bn1(self.conv1(merged)))
        return self.dropout(merged)

    def _init_weight(self):
        """Kaiming-init convs (bias -> 0); BN/Linear -> weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.Linear)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

class Decoder(nn.Module):
    """DeepLabv3+ decoder.

    Fuses reduced low-level backbone features with the upsampled ASPP
    output and predicts per-pixel class scores.
    """

    def __init__(self, num_classes, low_level_inplanes=256):
        super(Decoder, self).__init__()
        # Reduce the low-level features to 48 channels before fusion.
        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU()
        # 256 (ASPP) + 48 (low-level) = 304 fused input channels.
        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        )
        self._init_weight()

    def forward(self, x, low_level_feat):
        low = self.relu(self.bn1(self.conv1(low_level_feat)))
        # Bring the ASPP output up to the low-level spatial size, then fuse.
        upsampled = F.interpolate(x, size=low.size()[2:],
                                  mode='bilinear', align_corners=True)
        return self.last_conv(torch.cat((upsampled, low), dim=1))

    def _init_weight(self):
        """Kaiming-init convs (bias -> 0); BN/Linear -> weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.Linear)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)


class DeepLabv3Plus(nn.Module):
    """Full DeepLabv3+ network: ResNet backbone + ASPP + decoder, with the
    logits bilinearly upsampled back to the input resolution."""

    def __init__(self, num_classes=None):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = ResNet('resnet50', None)
        self.aspp = ASPP(inplanes=self.backbone.final_out_channels)
        self.decoder = Decoder(self.num_classes, self.backbone.low_level_inplanes)

    def forward(self, imgs, labels=None, mode='infer', **kwargs):
        """Return per-pixel class logits at the input's spatial size.

        ``labels``, ``mode`` and extra kwargs are accepted for interface
        compatibility but are not used here.
        """
        high, low = self.backbone(imgs)
        fused = self.decoder(self.aspp(high), low)
        return F.interpolate(fused, size=imgs.size()[2:],
                             mode='bilinear', align_corners=True)


    
def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


def label_accuracy_score(label_trues, label_preds, n_class):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
      - fwavacc
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
             hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
    acc = np.diag(hist).sum() / (hist.sum() + 1e-6)
    acc_cls = np.diag(hist) / (hist.sum(axis=1) + 1e-6)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / ((hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) + 1e-6)
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / (hist.sum() + 1e-6)
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iu, fwavacc


# Pixel-wise cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss().to(device)

# Model + plain SGD (no momentum) with L2 weight decay.
net=DeepLabv3Plus(num_classes = 2)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2, weight_decay=1e-4)


# Per-epoch history, consumed by the plotting code after training.
Loss_list = []
Accuracy_list = []


# ---------------------------------------------------------------------------
# Training loop: 50 epochs of train + validation, tracking pixel accuracy
# and mean IoU.  Fixes over the original version:
#   * dropped the needless `label.squeeze()` in the training loss --
#     labels arrive as (B, H, W) already, and squeeze() would drop the
#     batch dimension for a batch of size 1 and crash the loss;
#   * validation now runs under torch.no_grad() so no autograd graph is
#     built for batches that never call backward().
# ---------------------------------------------------------------------------
for e in range(50):

    train_loss = 0
    train_acc = 0
    train_acc_cls = 0
    train_mean_iu = 0
    train_fwavacc = 0

    prev_time = time.time()
    net = net.train()
    net = net.to(device)
    for data in train_data:
        im = data[0].to(device)
        label = data[1].to(device)
        # forward: logits of shape (B, num_classes, H, W)
        out = net(im)
        loss = criterion(out, label)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Per-image metrics on the argmax prediction.
        label_pred = out.max(dim=1)[1].data.cpu().numpy()
        label_true = label.data.cpu().numpy()
        for lbt, lbp in zip(label_true, label_pred):
            acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(lbt, lbp, num_classes)
            train_acc += acc
            train_acc_cls += acc_cls
            train_mean_iu += mean_iu
            train_fwavacc += fwavacc

    # ----- validation: no gradients needed -----
    net = net.eval()
    eval_loss = 0
    eval_acc = 0
    eval_acc_cls = 0
    eval_mean_iu = 0
    eval_fwavacc = 0
    with torch.no_grad():
        for data in valid_data:
            im = data[0].to(device)
            label = data[1].to(device)
            out = net(im)
            loss = criterion(out, label)
            eval_loss += loss.item()
            label_pred = out.max(dim=1)[1].data.cpu().numpy()
            label_true = label.data.cpu().numpy()
            for lbt, lbp in zip(label_true, label_pred):
                acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(lbt, lbp, num_classes)
                eval_acc += acc
                eval_acc_cls += acc_cls
                eval_mean_iu += mean_iu
                eval_fwavacc += fwavacc

    cur_time = time.time()

    epoch_str = ('Epoch: {}, Train Loss: {:.5f}, Train Acc: {:.5f}, Train Mean IU: {:.5f}, \
Valid Loss: {:.5f}, Valid Acc: {:.5f}, Valid Mean IU: {:.5f} '.format(
        e, train_loss / len(train_data), train_acc / len(voc_train), train_mean_iu / len(voc_train),
        eval_loss / len(valid_data), eval_acc / len(voc_test), eval_mean_iu / len(voc_test)))

    print(epoch_str)

    # Record per-epoch training loss / accuracy for the final plots.
    Loss_list.append(train_loss / (len(train_data)))
    Accuracy_list.append(100 * train_acc / (len(voc_train)))

    torch.cuda.empty_cache()


# ---------------------------------------------------------------------------
# Plot the training curves and persist the figure + model weights.
# Fixes over the original version:
#   * x ranges derive from the recorded history instead of a hard-coded
#     50, so the plot stays correct if the epoch count changes;
#   * plt.savefig() is called BEFORE plt.show() -- show() hands the figure
#     to the GUI and may leave it empty, so saving afterwards produced a
#     blank image.
# ---------------------------------------------------------------------------
epochs = range(len(Accuracy_list))
plt.subplot(2, 1, 1)
plt.plot(epochs, Accuracy_list, 'o-')
plt.title('Test accuracy vs. epoches')
plt.ylabel('Test accuracy')
plt.subplot(2, 1, 2)
plt.plot(range(len(Loss_list)), Loss_list, '.-')
plt.xlabel('Test loss vs. epoches')
plt.ylabel('Test loss')
plt.savefig("accuracy_loss.jpg")
plt.show()

torch.save(net.state_dict(), r'E:\INRIA_DATASET\deeplabv3+_resnet50.pth')

# Reload the trained weights into a fresh model for inference.
net = DeepLabv3Plus(num_classes = 2)

# NOTE: the checkpoint path must be a raw string -- the original non-raw
# 'E:\INRIA_DATASET\deeplabv3+...' contains the invalid escape sequences
# '\I' and '\d' (DeprecationWarning today, SyntaxError in the future) and
# was inconsistent with the raw string used by torch.save above.
net.load_state_dict(torch.load(r'E:\INRIA_DATASET\deeplabv3+_resnet50.pth'))
net = net.to(device)
net.eval()


# Class-index -> RGB colour lookup used to render predictions.
cm = np.array(colormap).astype('uint8')

def predict(im, label):
    """Run the network on one image tensor and colourise the result.

    Returns (pred_rgb, label_rgb): the argmax prediction and the ground
    truth, both mapped through the module-level `cm` colour table.
    """
    batch = im.unsqueeze(0).to(device)
    logits = net(batch)
    class_map = logits.max(1)[1].squeeze().cpu().data.numpy()
    return cm[class_map], cm[label.numpy()]
    
# Render and save the first 100 validation predictions as JPEGs.
for i in range(100):
    test_data, test_label = voc_test[i]
    pred, label = predict(test_data, test_label)
    # NOTE(review): `pred` is RGB while cv2.imwrite expects BGR; with the
    # black/white colormap the channels are identical, so output is correct.
    cv2.imwrite(r'E:\INRIA_DATASET\result\{}.jpg'.format(i),pred)
    

# --- CSDN blog-page boilerplate captured by the scrape; commented out so
# --- the file stays valid Python. Original text preserved below. ---
#   • 0
#     点赞
#   • 0
#     收藏
#     觉得还不错? 一键收藏
#   • 0
#     评论
#
# “相关推荐”对你有帮助么?
#
#   • 非常没帮助
#   • 没帮助
#   • 一般
#   • 有帮助
#   • 非常有帮助
# 提交
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值