一、裁剪：将影像与标签切成 512×512 小块；删除空白标签（全黑掩膜）及与之对应的 DSM 影像块
import cv2
import os
# Cut every input image (and every label mask) into heightBlock x widthBlock tiles.
inPath = "./dataset/sat_train/"    # source images
outPath = "./dataset/train/"       # tiles of both kinds are written here
inPath2 = "./dataset/mask_train/"  # source label masks
# The size of block that you want to cut
heightBlock = 512
widthBlock = 512


def _crop_dir(src_dir, dst_dir, suffix, block_h, block_w):
    """Cut every image in ``src_dir`` into ``block_h`` x ``block_w`` tiles.

    Tiles are written to ``dst_dir`` as ``<stem>(<row>,<col>)@<idx:04d><suffix>``,
    where ``<stem>`` is the file name without its last 4 characters (assumed to
    be a 3-letter extension such as ``.tif``) and ``<idx>`` numbers the tiles of
    one source image row-major starting at 0. Bottom/right strips smaller than a
    full tile are dropped. Files OpenCV cannot decode are skipped.
    """
    for entry in os.listdir(src_dir):
        name = entry.strip()
        path = src_dir + name
        print(path)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print("skipped (not a readable image): " + path)
            continue
        rows = img.shape[0] // block_h
        cols = img.shape[1] // block_w
        for i in range(rows):
            for j in range(cols):
                tile = img[i * block_h:(i + 1) * block_h, j * block_w:(j + 1) * block_w]
                idx = i * cols + j  # row-major running tile counter
                save_path = dst_dir + name[:-4] + "({},{})@{:04d}{}".format(i, j, idx, suffix)
                cv2.imwrite(save_path, tile)
                print(save_path)


# cv2.imwrite only returns False (no exception) when the target folder is missing.
os.makedirs(outPath, exist_ok=True)
_crop_dir(inPath, outPath, "_sat.tif", heightBlock, widthBlock)
_crop_dir(inPath2, outPath, "_mask.png", heightBlock, widthBlock)
print("finish!")
# Delete every tile pair whose mask contains no non-zero pixel (an empty label):
# the mask tile and the satellite tile cut from the same position.
# Match the exact suffix: the f[:-9] slice below assumes names end in '_mask.png'.
mask_names = [name for name in os.listdir(outPath) if name.endswith('_mask.png')]
for f in mask_names:
    path = outPath + f
    if not os.path.exists(path):
        continue
    img = cv2.imread(path, 0)  # 0 == grayscale; countNonZero needs a single channel
    if img is None:
        continue  # not decodable by OpenCV; leave it untouched
    if cv2.countNonZero(img) == 0:
        print(f+'Image is black')
        # Drop the 9-character '_mask.png' suffix to get the prefix shared with the sat tile.
        sat_path = outPath + f[:-9] + "_sat.tif"
        os.remove(path)
        if os.path.exists(sat_path):
            os.remove(sat_path)
二、分析（fenxi）：按单张 IoU 将预测结果分档复制到 iou_100 / iou_80 / iou_50 / iou_30 文件夹
import os
import shutil
# Sort every validation sample into a folder by its IoU score, copying the
# prediction, the RGB image, the valid image and the ground truth side by side.
data_path = './submits/log01_Dink101_five_100/test_iou/'
with open(os.path.join(data_path, "log01_Dink101_five_100_excel.txt"), 'r') as fh:
    data = fh.read().splitlines()
valid_path = './dataset/valid/'
rgb_path = './dataset/valid_all/'
real_path = './dataset/real/'
# Each folder is named after the upper bound of its IoU range.
iou_100 = os.path.join(data_path, 'iou_100/')  # iou >= 80
iou_80 = os.path.join(data_path, 'iou_80/')    # 50 <= iou < 80
iou_50 = os.path.join(data_path, 'iou_50/')    # 30 <= iou < 50
iou_30 = os.path.join(data_path, 'iou_30/')    # iou < 30
# Create each folder on its own; a partial set left by an earlier run must not crash.
for folder in (iou_100, iou_80, iou_50, iou_30):
    os.makedirs(folder, exist_ok=True)


def _iou_bucket(iou):
    """Return the destination folder for an IoU score."""
    if iou >= 80:
        return iou_100
    if iou >= 50:
        return iou_80
    if iou >= 30:
        return iou_50
    return iou_30


def _copy_sample(name, iou, dst):
    """Copy the four files describing sample ``name`` into folder ``dst``.

    ``name[:-4]`` drops the trailing ``mask`` of the sample name to get the
    prefix shared by ``...sat.tif`` / ``...mask.png``. Files are copied directly
    under their final names:
      - prediction     -> ``<prefix>tmask_<iou>.png``
      - RGB image      -> ``<prefix>rgb.tif``
      - valid image    -> ``<prefix>sat.tif``
      - ground truth   -> ``<prefix>tmask.png``
    """
    prefix = name[:-4]
    shutil.copy(os.path.join(data_path, 'test_pre_img/' + name + '.png'),
                os.path.join(dst, prefix + 'tmask_' + str(iou) + '.png'))
    shutil.copy(os.path.join(rgb_path, prefix + 'sat.tif'),
                os.path.join(dst, prefix + 'rgb.tif'))
    shutil.copy(os.path.join(valid_path, prefix + 'sat.tif'), dst)
    shutil.copy(os.path.join(real_path, prefix + 'mask.png'),
                os.path.join(dst, prefix + 'tmask.png'))


for n in data:
    fields = n.split()
    if not fields:
        continue  # skip blank lines
    # assumes whitespace-separated rows with the name in column 1 and IoU in column 2
    name = fields[1]
    iou = float(fields[2])
    _copy_sample(name, iou, _iou_bucket(iou))
    print(name, iou)
print('Finish')
三、筛选 IoU ≥ 30 的预测图（select iou>=30）
import os
import shutil
# Copy the prediction images whose IoU is at least 30 into test_pre_img/.
data_path = './submits/log01_Dink101_five_100/test_iou/'
with open(os.path.join(data_path, "log01_Dink101_five_100_excel.txt"), 'r') as fh:
    data = fh.read().splitlines()
iou_100 = os.path.join(data_path, 'test_pre_img/')  # destination folder
os.makedirs(iou_100, exist_ok=True)
for n in data:
    fields = n.split()
    if not fields:
        continue  # skip blank lines
    # assumes the name in column 1 and the IoU in column 2
    name = fields[1]
    iou = float(fields[2])
    if iou < 30:
        continue
    img_path = os.path.join(data_path, 'test_pre_img87.24/' + name + '.png')
    shutil.copy(img_path, iou_100)
    print(name, iou)
print('Finish')
四、删除没有对应预测图的真值标签（delete real）
import os
import cv2
# Delete every ground-truth mask in real_path that has no prediction with the
# same file name in pre_path, so both folders list the same samples.
real_path = "./dataset/real/"
pre_path = "./submits/log01_Dink101_five_100/test_iou/test_pre_img/"
if not os.path.isdir(pre_path):
    # Without this guard every mask would look orphaned and be deleted.
    raise FileNotFoundError(pre_path)
real_names = [name for name in os.listdir(real_path) if 'mask' in name]
for f in real_names:
    pre_name = pre_path + f.strip()
    if not os.path.exists(pre_name):
        os.remove(real_path + f.strip())
        print(real_path + f.strip())
五、评估（predict）：对每个权重已生成的预测图计算 mIoU 与单张 IoU
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
#从pyplot导入MultipleLocator类,这个类用于设置刻度间隔
from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
# Per-GPU batch size. Not referenced by the evaluation code below, which only
# scores prediction images that already exist on disk.
BATCHSIZE_PER_CARD = 32
def saveList(pathName, out_path="./dataset/gt.txt"):
    """Append the extension-less name of every entry in ``pathName`` to a list file.

    One name per line; the name is everything before the first ``.``
    (``"a_mask.png"`` -> ``"a_mask"``). The file is opened in append mode, so
    callers that want a fresh list must truncate it first.

    Args:
        pathName: iterable of file names, e.g. the result of ``os.listdir``.
        out_path: list file to append to (defaults to ``./dataset/gt.txt``).
    """
    names = [file_name.split(".")[0] for file_name in pathName]
    if not names:
        return  # nothing to write: do not create the file
    # Open the file once for all names; the with-block closes it.
    with open(out_path, "a") as f:
        f.writelines(name + "\n" for name in names)
def savetrainList(pathName, out_path="./dataset/gt_train.txt"):
    """Append the extension-less name of every entry in ``pathName`` to the train list file.

    One name per line; the name is everything before the first ``.``. The file
    is opened in append mode, so callers must truncate it for a fresh list.

    Args:
        pathName: iterable of file names, e.g. the result of ``os.listdir``.
        out_path: list file to append to (defaults to ``./dataset/gt_train.txt``).
    """
    names = [file_name.split(".")[0] for file_name in pathName]
    if not names:
        return  # nothing to write: do not create the file
    # Open the file once for all names; the with-block closes it.
    with open(out_path, "a") as f:
        f.writelines(name + "\n" for name in names)
def dirList(gt_dir, path_list):
    """Add the contents of sub-directories of ``gt_dir`` to the gt list.

    For each entry of ``path_list`` that is a directory inside ``gt_dir``, the
    names of that directory's entries are passed to ``saveList``. Plain files
    are ignored here.
    """
    for entry in path_list:
        path = os.path.join(gt_dir, entry)
        if os.path.isdir(path):
            saveList(os.listdir(path))
print("开始运行!")
# Score the prediction images stored for every checkpoint in ./weights/ against
# the ground-truth masks in ./dataset/real/ and log the results per checkpoint.
weight_dir = "./weights/"
# Sort numerically by the integer at x[19:-3] (19-character run prefix, 3-character
# extension). Assumes every file in weight_dir follows that naming; otherwise int() raises.
weight_list = sorted(os.listdir(weight_dir), key=lambda x: int(x[19:-3]))

miou_mode = 2  # 0 or 2 -> compute mIoU; no other mode is implemented here
num_classes = 2  # number of classes, background included
name_classes = ["nonwater", "water"]  # class names, in label-id order
data_path = './dataset/'
gt_dir = os.path.join(data_path, "real/")
gt_list_path = os.path.join(data_path, "gt.txt")

# The gt list depends only on gt_dir, so build it once instead of once per checkpoint.
with open(gt_list_path, 'w'):
    pass  # truncate; dirList/saveList only append
path_list = sorted(os.listdir(gt_dir))
dirList(gt_dir, path_list)
saveList(path_list)
with open(gt_list_path, 'r') as fh:
    image_ids = fh.read().splitlines()

with open('submits/count_low_pic.log', 'w') as mylog:
    for weight_name in weight_list:
        target = os.path.join('./submits/', weight_name[:-3] + '/' + 'test_iou/')
        lower_iou = os.path.join('./submits/', weight_name[:-3] + '/' + 'lower_iou/')
        higher_iou = os.path.join('./submits/', weight_name[:-3] + '/' + 'higher_iou/')
        # Created up front in case compute_IoU writes into them (its code is not in this file).
        os.makedirs(lower_iou, exist_ok=True)
        os.makedirs(higher_iou, exist_ok=True)
        pred_dir = os.path.join(target, 'test_pre_img/')
        if miou_mode == 0 or miou_mode == 2:
            mylog.write(str(weight_name[:-3]))
            print(weight_name + " Get miou.")
            print('计算测试miou')
            test_mIou, test_mPA, test_miou, test_mpa = compute_mIoU(
                gt_dir, pred_dir, image_ids, num_classes, name_classes, weight_name)
            mylog.write(' test_mIoU: ' + str(test_miou))
            mylog.write(' test_mPA: ' + str(test_mpa))
            print(' test_mIoU: ' + str(test_miou))
            print('计算测试样本单张iou')
            count = compute_IoU(gt_dir, pred_dir, image_ids, num_classes,
                                lower_iou, higher_iou, weight_name, 0)
            mylog.write(' low-iou test picture num: ' + str(count))
            # End the record so successive checkpoints do not run together on one line.
            mylog.write('\n')
            print(weight_name + "Get miou done.")
    mylog.write('Finish!')
print('Finish!')
六、开运算与闭运算
开运算：先腐蚀后膨胀，用于去除前景外的小白点（噪声）；闭运算：先膨胀后腐蚀，用于填充前景对象中的小孔或对象上的小黑点。输出：1/ 为开运算结果，2/ 为闭运算结果，3/ 为先开后闭的结果。
import cv2 as cv
import numpy as np
import os
# Post-process prediction masks with morphological opening and closing.
pre_path = '/mnt/sdb1/fenghaixia/dsm/test_pre_img89.49/'
out_root = '/mnt/sdb1/fenghaixia/dsm/submits/log01_Dink101_five_100/test_iou/'
open_dir = out_root + '1/'        # opening only
close_dir = out_root + '2/'       # closing only
open_close_dir = out_root + '3/'  # opening followed by closing
for folder in (open_dir, close_dir, open_close_dir):
    # cv.imwrite returns False instead of raising when the folder is missing.
    os.makedirs(folder, exist_ok=True)
k = np.ones((5, 5), np.uint8)  # 5x5 square structuring element, built once
for f in os.listdir(pre_path):
    image = cv.imread(pre_path + f)
    if image is None:
        print('skipped (not a readable image): ' + pre_path + f)
        continue
    # Not named `open`: that would shadow the built-in open() for the rest of the module.
    opened = cv.morphologyEx(image, cv.MORPH_OPEN, k)
    cv.imwrite(open_dir + f, opened)
    closed = cv.morphologyEx(image, cv.MORPH_CLOSE, k)
    cv.imwrite(close_dir + f, closed)
    opened_closed = cv.morphologyEx(opened, cv.MORPH_CLOSE, k)
    cv.imwrite(open_close_dir + f, opened_closed)
print('fnish')
七、只计算 IoU（不重新预测，直接评估已有预测图）
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
#从pyplot导入MultipleLocator类,这个类用于设置刻度间隔
from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
BATCHSIZE_PER_CARD = 16  # images per GPU in one forward pass


class TTAFrame:
    """Test-time-augmentation inference wrapper around a segmentation network.

    ``net`` is a model class (called with no arguments). The instance is moved to
    CUDA and wrapped in ``DataParallel`` over all visible GPUs.
    """

    def __init__(self, net):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))

    def test_one_img_from_path(self, path, evalmode=True):
        """Predict the image at ``path`` with 8-view TTA.

        Args:
            path: image file readable by ``cv2.imread``.
            evalmode: switch the network to eval mode first.

        Raises:
            RuntimeError: fewer than 8 images fit in one batch across all GPUs;
                only the 8-views-in-one-batch variant is implemented.
        """
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize >= 8:
            return self.test_one_img_from_path_1(path)
        # Fail loudly instead of implicitly returning None.
        raise RuntimeError(
            "TTA needs a batch of at least 8 images, got %d "
            "(GPUs x BATCHSIZE_PER_CARD)" % batchsize)

    def test_one_img_from_path_1(self, path):
        """Run all 8 TTA views of one image through the network in a single batch.

        Views: the image, its np.rot90 rotation, the vertical flips of both, and
        the horizontal flips of those four. Each output is mapped back to the
        original orientation and the 8 maps are summed (not averaged). The image
        must be square so that the rotated view has the same shape. Pixels are
        scaled from [0, 255] to [-1.6, 1.6]. Assumes the network returns one
        channel per image, so ``squeeze()`` gives (8, H, W) — TODO confirm.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError("cannot read image: " + path)
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None], img90[None]])
        img2 = np.array(img1)[:, ::-1]                 # vertical flips
        img3 = np.concatenate([img1, img2])
        img4 = np.array(img3)[:, :, ::-1]              # horizontal flips
        img5 = np.concatenate([img3, img4]).transpose(0, 3, 1, 2)  # NHWC -> NCHW
        img5 = np.array(img5, np.float32) / 255.0 * 3.2 - 1.6
        img5 = V(torch.Tensor(img5).cuda())
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            mask = self.net.forward(img5).squeeze().cpu().data.numpy()
        mask1 = mask[:4] + mask[4:, :, ::-1]           # undo horizontal flips
        mask2 = mask1[:2] + mask1[2:, ::-1]            # undo vertical flips
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1, ::-1]  # undo the 90-degree rotation
        return mask3

    def load(self, path):
        """Load a state dict into the DataParallel-wrapped network.

        Keys must match the wrapped model, i.e. carry the ``module.`` prefix.
        ``torch.load`` unpickles the file: only load trusted checkpoints.
        """
        self.net.load_state_dict(torch.load(path))
def saveList(pathName, out_path="./dataset/gt.txt"):
    """Append the extension-less name of every entry in ``pathName`` to a list file.

    One name per line; the name is everything before the first ``.``
    (``"a_mask.png"`` -> ``"a_mask"``). The file is opened in append mode, so
    callers that want a fresh list must truncate it first.

    Args:
        pathName: iterable of file names, e.g. the result of ``os.listdir``.
        out_path: list file to append to (defaults to ``./dataset/gt.txt``).
    """
    names = [file_name.split(".")[0] for file_name in pathName]
    if not names:
        return  # nothing to write: do not create the file
    # Open the file once for all names; the with-block closes it.
    with open(out_path, "a") as f:
        f.writelines(name + "\n" for name in names)
def savetrainList(pathName, out_path="./dataset/gt_train.txt"):
    """Append the extension-less name of every entry in ``pathName`` to the train list file.

    One name per line; the name is everything before the first ``.``. The file
    is opened in append mode, so callers must truncate it for a fresh list.

    Args:
        pathName: iterable of file names, e.g. the result of ``os.listdir``.
        out_path: list file to append to (defaults to ``./dataset/gt_train.txt``).
    """
    names = [file_name.split(".")[0] for file_name in pathName]
    if not names:
        return  # nothing to write: do not create the file
    # Open the file once for all names; the with-block closes it.
    with open(out_path, "a") as f:
        f.writelines(name + "\n" for name in names)
def dirList(gt_dir, path_list):
    """Add the contents of sub-directories of ``gt_dir`` to the gt list.

    For each entry of ``path_list`` that is a directory inside ``gt_dir``, the
    names of that directory's entries are passed to ``saveList``. Plain files
    are ignored here.
    """
    for entry in path_list:
        path = os.path.join(gt_dir, entry)
        if os.path.isdir(path):
            saveList(os.listdir(path))
print("开始运行!")
# Compute mIoU / mPA (no per-image IoU) for the post-processed predictions of
# every checkpoint in ./weights/ and log one record per checkpoint.
weight_dir = "./weights/"
# Sort numerically by the integer at x[19:-3] (19-character run prefix, 3-character
# extension). Assumes every file in weight_dir follows that naming; otherwise int() raises.
weight_list = sorted(os.listdir(weight_dir), key=lambda x: int(x[19:-3]))

miou_mode = 2  # 0 or 2 -> compute mIoU; no other mode is implemented here
num_classes = 2  # number of classes, background included
name_classes = ["nonwater", "water"]  # class names, in label-id order
data_path = './dataset/'
gt_dir = os.path.join(data_path, "real/")
gt_list_path = os.path.join(data_path, "gt.txt")

# The gt list depends only on gt_dir, so build it once instead of once per checkpoint.
with open(gt_list_path, 'w'):
    pass  # truncate; dirList/saveList only append
path_list = sorted(os.listdir(gt_dir))
dirList(gt_dir, path_list)
saveList(path_list)
with open(gt_list_path, 'r') as fh:
    image_ids = fh.read().splitlines()

with open('submits/count_low_pic.log', 'w') as mylog:
    for weight_name in weight_list:
        target = os.path.join('./submits/', weight_name[:-3] + '/' + 'test_iou/')
        lower_iou = os.path.join('./submits/', weight_name[:-3] + '/' + 'lower_iou/')
        higher_iou = os.path.join('./submits/', weight_name[:-3] + '/' + 'higher_iou/')
        # makedirs also creates ./submits/<run>/ and tolerates folders from earlier runs.
        for folder in (target, lower_iou, higher_iou):
            os.makedirs(folder, exist_ok=True)
        # Predictions are read from the '3/' sub-folder of test_iou.
        pred_dir = os.path.join(target, '3/')
        if miou_mode == 0 or miou_mode == 2:
            mylog.write(str(weight_name[:-3]))
            print(weight_name + " Get miou.")
            print('计算测试miou')
            test_mIou, test_mPA, test_miou, test_mpa = compute_mIoU(
                gt_dir, pred_dir, image_ids, num_classes, name_classes, weight_name)
            mylog.write(' test_mIoU: ' + str(test_miou))
            mylog.write(' test_mPA: ' + str(test_mpa))
            print(' test_mIoU: ' + str(test_miou))
            # End the record so successive checkpoints do not run together on one line.
            mylog.write('\n')
    mylog.write('Finish!')
print('Finish!')