# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 16:55:26 2021
@author: Administrator
"""
import os
import numpy as np
from torchvision import transforms as tfs
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
import cv2
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import torch.nn as nn
from torchvision import models
import time
import torch.nn.functional as F
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from PIL import Image, ImageFile
# Allow PIL to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Work around the "duplicate OpenMP runtime" abort seen with some
# MKL/conda installations on Windows.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# BUG FIX: the cuDNN switch is spelled `enabled`; assigning to the
# misspelled `enable` merely created an unused attribute and had no effect.
torch.backends.cudnn.enabled = True
# Let cuDNN benchmark and cache the fastest conv algorithms (inputs here
# are fixed-size tiles, so benchmarking pays off).
torch.backends.cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')
# Root directory of the pre-tiled INRIA aerial building dataset.
inria_root = r'E:\INRIA_DATASET'

def read_images(root=inria_root, train=True):
    """Build index-aligned (image paths, label paths) lists for one split.

    Tiles are stored as zero-padded six-digit ``.tif`` files: the training
    split holds tiles 0..49999 and the validation split tiles 50000..64979.
    No file I/O happens here; only path strings are constructed.

    Parameters
    ----------
    root : str
        Dataset root directory (default: ``inria_root``).
    train : bool
        ``True`` for the training split, ``False`` for validation.

    Returns
    -------
    (list[str], list[str])
        Image file paths and the matching label file paths.
    """
    # Single parameterised path builder replaces the duplicated branches.
    if train:
        split, start, count = 'train', 0, 50000
    else:
        split, start, count = 'val', 50000, 14980
    data = [os.path.join(root, split, 'Image_batch', '{:0>6d}.tif'.format(i))
            for i in range(start, start + count)]
    print('{}_data finish'.format(split))
    label = [os.path.join(root, split, 'label_batch', '{:0>6d}.tif'.format(i))
             for i in range(start, start + count)]
    print('{}_label finish'.format(split))
    return data, label
# Class / colour definitions for the binary building-segmentation task.
classes = ['background', 'building']
colormap = [[0, 0, 0], [255, 255, 255]]  # RGB colour of each class

# Lookup table mapping a packed 24-bit RGB value to its class index.
# A pixel (r, g, b) is packed as (r * 256 + g) * 256 + b.
# (The original dead expression `len(classes), len(colormap)` was removed.)
cm2lbl = np.zeros(256 ** 3)
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

def image2label(im):
    """Convert an RGB label image into an (H, W) int64 class-index matrix.

    Parameters
    ----------
    im : array-like
        RGB label image of shape (H, W, 3) with integer channel values.

    Returns
    -------
    numpy.ndarray
        int64 matrix of class indices looked up via ``cm2lbl``.
    """
    data = np.array(im, dtype='int64')
    # Pack the three channels into one integer index per pixel.
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    return np.array(cm2lbl[idx], dtype='int64')
# PERF FIX: the normalisation pipeline is stateless, so build it once at
# module level instead of re-creating the Compose object on every call
# (this function runs once per sample, per epoch).
_IM_TFS = tfs.Compose([
    tfs.ToTensor(),
    # ImageNet mean / std, matching the pretrained ResNet backbone.
    tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def img_transforms(im, label):
    """Convert a (PIL image, PIL label) pair into training tensors.

    Parameters
    ----------
    im : PIL.Image
        RGB input image.
    label : PIL.Image
        RGB label image using the colours in ``colormap``.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        Normalised float image (3, H, W) and int64 class-index label (H, W).
    """
    im = _IM_TFS(im)
    label = torch.from_numpy(image2label(label)).long()
    return im, label
class VOCSegDataset(Dataset):
    """Segmentation dataset over the pre-tiled INRIA image/label files.

    File paths are resolved once at construction time; the actual images
    are opened lazily in ``__getitem__``.
    """

    def __init__(self, train, transforms):
        self.transforms = transforms
        # read_images returns index-aligned (image paths, label paths).
        self.data_list, self.label_list = read_images(train=train)
        print('Read' + str(len(self.data_list)) + 'images')

    def __getitem__(self, idx):
        img = Image.open(self.data_list[idx])
        # Labels are stored as RGB colour maps; force 3-channel mode.
        label = Image.open(self.label_list[idx]).convert('RGB')
        return self.transforms(img, label)

    def __len__(self):
        return len(self.data_list)
# Datasets and loaders for the two splits; only the training loader shuffles.
voc_train = VOCSegDataset(train=True, transforms=img_transforms)
voc_test = VOCSegDataset(train=False, transforms=img_transforms)
train_data = DataLoader(dataset=voc_train, batch_size=16, shuffle=True)
valid_data = DataLoader(dataset=voc_test, batch_size=16)
num_classes = len(classes)
class ResNet(nn.Module):
    """Truncated torchvision ResNet used as the DeepLabv3+ backbone.

    ``forward`` returns a (high-level, low-level) feature pair: the output
    of layer3 (with its first block's strides forced to 1) and the output
    of layer1.
    """

    def __init__(self, backbone='resnet50', pretrained_path=None):
        super().__init__()
        # backbone name -> (constructor, high-level channels, low-level channels).
        # Any unrecognised name falls through to resnet152, mirroring the
        # original if/elif chain's final else branch.
        configs = {
            'resnet18': (resnet18, 256, 64),
            'resnet34': (resnet34, 256, 64),
            'resnet50': (resnet50, 1024, 256),
            'resnet101': (resnet101, 1024, 256),
        }
        ctor, high_ch, low_ch = configs.get(backbone, (resnet152, 1024, 256))
        self.final_out_channels = high_ch
        self.low_level_inplanes = low_ch
        # Download ImageNet weights only when no local checkpoint is given.
        model = ctor(pretrained=not pretrained_path)
        if pretrained_path:
            model.load_state_dict(torch.load(pretrained_path))
        # children()[:5] = stem (conv1, bn1, relu, maxpool) + layer1;
        # children()[5:7] = layer2 + layer3 (avgpool/fc are dropped).
        self.early_extractor = nn.Sequential(*list(model.children())[:5])
        self.later_extractor = nn.Sequential(*list(model.children())[5:7])
        # Keep layer3 at layer2's spatial resolution by removing the
        # stride-2 downsampling in its first block.
        conv4_block1 = self.later_extractor[-1][0]
        conv4_block1.conv1.stride = (1, 1)
        conv4_block1.conv2.stride = (1, 1)
        conv4_block1.downsample[0].stride = (1, 1)

    def forward(self, x):
        low = self.early_extractor(x)
        high = self.later_extractor(low)
        return high, low
class _ASPPModule(nn.Module):
def __init__(self, inplanes, planes, kernel_size, padding, dilation):
super(_ASPPModule, self).__init__()
self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
stride=1, padding=padding, dilation=dilation, bias=False)
self.bn = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self._init_weight()
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head (DeepLabv3).

    Runs four parallel atrous branches plus an image-level pooling branch,
    then fuses the five 256-channel maps with a 1x1 convolution.
    """

    def __init__(self, inplanes=2048, output_stride=16):
        super(ASPP, self).__init__()
        # Dilation rates depend on how much the backbone downsamples.
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError
        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0,
                                 dilation=dilations[0])
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1],
                                 dilation=dilations[1])
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2],
                                 dilation=dilations[2])
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3],
                                 dilation=dilations[3])
        # Image-level context: global average pool -> 1x1 conv -> BN -> ReLU.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        # 5 branches x 256 channels = 1280, fused down to 256.
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        pooled = self.global_avg_pool(x)
        # Upsample the 1x1 pooled features back to the branch resolution.
        pooled = F.interpolate(pooled, size=branches[-1].size()[2:],
                               mode='bilinear', align_corners=True)
        branches.append(pooled)
        fused = self.relu(self.bn1(self.conv1(torch.cat(branches, dim=1))))
        return self.dropout(fused)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.Linear)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class Decoder(nn.Module):
    """DeepLabv3+ decoder: fuse ASPP output with low-level backbone features."""

    def __init__(self, num_classes, low_level_inplanes=256):
        super(Decoder, self).__init__()
        # Compress the low-level features to 48 channels before fusion.
        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU()
        # 256 (ASPP) + 48 (low level) = 304 input channels.
        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
        self._init_weight()

    def forward(self, x, low_level_feat):
        low = self.relu(self.bn1(self.conv1(low_level_feat)))
        # Upsample ASPP features to the low-level resolution, then fuse.
        x = F.interpolate(x, size=low.size()[2:], mode='bilinear',
                          align_corners=True)
        return self.last_conv(torch.cat((x, low), dim=1))

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.Linear)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class DeepLabv3Plus(nn.Module):
    """Full DeepLabv3+ model: ResNet-50 backbone + ASPP + decoder."""

    def __init__(self, num_classes=None):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = ResNet('resnet50', None)
        self.aspp = ASPP(inplanes=self.backbone.final_out_channels)
        self.decoder = Decoder(self.num_classes,
                               self.backbone.low_level_inplanes)

    def forward(self, imgs, labels=None, mode='infer', **kwargs):
        """Return per-pixel class logits at the input resolution."""
        high, low = self.backbone(imgs)
        fused = self.decoder(self.aspp(high), low)
        # Bilinear upsampling back to the original image size.
        return F.interpolate(fused, size=imgs.size()[2:], mode='bilinear',
                             align_corners=True)
def _fast_hist(label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
return hist
def label_accuracy_score(label_trues, label_preds, n_class):
"""Returns accuracy score evaluation result.
- overall accuracy
- mean accuracy
- mean IU
- fwavacc
"""
hist = np.zeros((n_class, n_class))
for lt, lp in zip(label_trues, label_preds):
hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
acc = np.diag(hist).sum() / (hist.sum() + 1e-6)
acc_cls = np.diag(hist) / (hist.sum(axis=1) + 1e-6)
acc_cls = np.nanmean(acc_cls)
iu = np.diag(hist) / ((hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) + 1e-6)
mean_iu = np.nanmean(iu)
freq = hist.sum(axis=1) / (hist.sum() + 1e-6)
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
return acc, acc_cls, mean_iu, fwavacc
# Loss, model and optimiser for 50 epochs of SGD training.
criterion = nn.CrossEntropyLoss().to(device)
net = DeepLabv3Plus(num_classes=2)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2, weight_decay=1e-4)
Loss_list = []      # per-epoch mean training loss (for plotting)
Accuracy_list = []  # per-epoch mean training pixel accuracy, in percent

for e in range(50):
    # ---- training pass ----
    train_loss = 0
    train_acc = 0
    train_acc_cls = 0
    train_mean_iu = 0
    train_fwavacc = 0
    net = net.train()
    net = net.to(device)
    for data in train_data:
        im = data[0].to(device)
        label = data[1].to(device)  # (B, H, W) int64 class indices
        out = net(im)               # (B, num_classes, H, W) logits
        # BUG FIX: the original called label.squeeze() here, which would
        # also squeeze the batch dimension for a size-1 batch and crash
        # CrossEntropyLoss; the target is already the (B, H, W) shape the
        # loss expects (the validation loop below never squeezed).
        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # Per-image metrics accumulated over the whole epoch.
        label_pred = out.max(dim=1)[1].data.cpu().numpy()
        label_true = label.data.cpu().numpy()
        for lbt, lbp in zip(label_true, label_pred):
            acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(lbt, lbp, num_classes)
            train_acc += acc
            train_acc_cls += acc_cls
            train_mean_iu += mean_iu
            train_fwavacc += fwavacc

    # ---- validation pass ----
    net = net.eval()
    eval_loss = 0
    eval_acc = 0
    eval_acc_cls = 0
    eval_mean_iu = 0
    eval_fwavacc = 0
    # FIX: evaluate without building autograd graphs — the original kept
    # gradients alive through validation, wasting GPU memory and time.
    with torch.no_grad():
        for data in valid_data:
            im = data[0].to(device)
            label = data[1].to(device)
            out = net(im)
            loss = criterion(out, label)
            eval_loss += loss.item()
            label_pred = out.max(dim=1)[1].data.cpu().numpy()
            label_true = label.data.cpu().numpy()
            for lbt, lbp in zip(label_true, label_pred):
                acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(lbt, lbp, num_classes)
                eval_acc += acc
                eval_acc_cls += acc_cls
                eval_mean_iu += mean_iu
                eval_fwavacc += fwavacc

    epoch_str = ('Epoch: {}, Train Loss: {:.5f}, Train Acc: {:.5f}, Train Mean IU: {:.5f}, '
                 'Valid Loss: {:.5f}, Valid Acc: {:.5f}, Valid Mean IU: {:.5f} '.format(
                     e, train_loss / len(train_data), train_acc / len(voc_train),
                     train_mean_iu / len(voc_train), eval_loss / len(valid_data),
                     eval_acc / len(voc_test), eval_mean_iu / len(voc_test)))
    print(epoch_str)
    Loss_list.append(train_loss / (len(train_data)))
    Accuracy_list.append(100 * train_acc / (len(voc_train)))
    torch.cuda.empty_cache()
# ---- training-curve plot ----
epochs = range(0, 50)
plt.subplot(2, 1, 1)
plt.plot(epochs, Accuracy_list, 'o-')
plt.title('Test accuracy vs. epoches')
plt.ylabel('Test accuracy')
plt.subplot(2, 1, 2)
plt.plot(epochs, Loss_list, '.-')
plt.xlabel('Test loss vs. epoches')
plt.ylabel('Test loss')
# BUG FIX: savefig must run before show() — show() finalises and clears
# the current figure, so saving afterwards wrote a blank image.
plt.savefig("accuracy_loss.jpg")
plt.show()

# ---- checkpoint save / reload ----
# BUG FIX: the reload path was a plain string containing '\I' and '\d'
# backslash sequences; use one raw-string path for both save and load.
model_path = r'E:\INRIA_DATASET\deeplabv3+_resnet50.pth'
torch.save(net.state_dict(), model_path)
net = DeepLabv3Plus(num_classes=2)
# map_location makes the reload work even when the checkpoint was saved
# from GPU tensors and the current machine has no CUDA device.
net.load_state_dict(torch.load(model_path, map_location=device))
net = net.to(device)
net.eval()
# Class-index -> RGB colour lookup for rendering predictions.
cm = np.array(colormap).astype('uint8')

def predict(im, label):
    """Run the network on one (image, label) tensor pair.

    Parameters
    ----------
    im : torch.Tensor
        Normalised image tensor (3, H, W) as produced by the dataset.
    label : torch.Tensor
        Ground-truth class-index tensor (H, W) on the CPU.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Colourised prediction and colourised ground truth, both (H, W, 3)
        uint8 arrays.
    """
    im = im.unsqueeze(0).to(device)
    # FIX: inference needs no autograd graph.
    with torch.no_grad():
        out = net(im)
    pred = out.max(1)[1].squeeze().cpu().data.numpy()
    return cm[pred], cm[label.numpy()]

# Write the first 100 validation predictions to disk.
# NOTE: cv2.imwrite expects BGR order; the black/white colormap is
# symmetric under the RGB<->BGR swap, so no conversion is needed here.
result_dir = r'E:\INRIA_DATASET\result'
os.makedirs(result_dir, exist_ok=True)  # FIX: imwrite fails silently if the directory is missing
for i in range(100):
    test_data, test_label = voc_test[i]
    pred, label = predict(test_data, test_label)
    cv2.imwrite(os.path.join(result_dir, '{}.jpg'.format(i)), pred)
# 2021-07-05
# (Non-code residue from the blog page this script was copied from,
# commented out because it is a Python syntax error: "latest recommended
# article published 2022-08-17 17:32:33".)