使用 U-Net 深度学习网络识别墙面裂缝
准备工作：
- Windows 10
- Git
- Miniconda
- PyTorch
步骤：
1. 下载裂缝数据集
从 GitHub 上的原始仓库下载数据集。
2. 将 mat 格式的标注（groundTruth）转换为 jpg 图片
import os
from os.path import isdir
import numpy as np
from PIL import Image
from scipy import io
def mat_to_mask(mat_path):
    """Load one groundTruth ``.mat`` file and return its segmentation as an 8-bit mask.

    The file must contain a ``groundTruth`` struct with a ``Segmentation``
    field. Each label value is mapped to ``(label - 1) * 255`` and clipped
    to [0, 255], so label 1 becomes 0 and label 2 becomes 255.

    Args:
        mat_path: path to the ``.mat`` file.

    Returns:
        A ``numpy.uint8`` array with the same shape as the label array.
    """
    image_mat = io.loadmat(mat_path)
    # loadmat wraps the MATLAB struct in nested 1x1 object arrays; the two
    # [0] lookups unwrap it down to the label array itself.
    segmentation = np.asarray(image_mat['groundTruth']['Segmentation'][0][0])
    # Do the arithmetic in a wider signed type and clip, so unexpected labels
    # cannot wrap around as they would in uint8. The final uint8 cast also
    # guarantees Pillow gets an 8-bit image it can encode as JPEG, even if
    # loadmat returned a float array.
    mask = np.clip((segmentation.astype(np.int32) - 1) * 255, 0, 255)
    return mask.astype(np.uint8)


def convert_ground_truth_dir(mat_dir, out_dir):
    """Convert every ``.mat`` file in *mat_dir* into a ``.jpg`` mask in *out_dir*.

    Each output keeps its input file's stem, e.g. ``001.mat`` -> ``001.jpg``.
    *out_dir* (including missing parents) is created if needed.

    Args:
        mat_dir: directory containing the groundTruth ``.mat`` files.
        out_dir: directory that receives the converted ``.jpg`` masks.

    Returns:
        List of written JPEG paths, in sorted input-filename order.
    """
    os.makedirs(out_dir, exist_ok=True)
    written = []
    for name in sorted(os.listdir(mat_dir)):
        stem, ext = os.path.splitext(name)
        # Skip entries that are not .mat files (e.g. Thumbs.db / desktop.ini
        # that Windows may drop into folders), which io.loadmat cannot parse.
        if ext.lower() != ".mat":
            continue
        mask = mat_to_mask(os.path.join(mat_dir, name))
        out_path = os.path.join(out_dir, stem + ".jpg")
        # NOTE: JPEG is lossy, so pixels near mask edges may get grey values
        # between 0 and 255; threshold the masks again when reading them back.
        Image.fromarray(mask).save(out_path)
        written.append(out_path)
    return written


if __name__ == '__main__':
    # Replace these two placeholder paths before running the script.
    file_path = r"此处改为你自己下载后的 groundTruth 路径"
    png_img_dir = r"此处改为你要把转换后 jpg 保存路径"
    convert_ground_truth_dir(file_path, png_img_dir)
3. U-Net 代码
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler, optimizer
import torchvision
import os, sys
import cv2 as cv
from torch.utils.data import DataLoader, sampler
class SegmentationDataset(object):
def __init__(self, image_dir, mask_dir):
    """Collect paired image/mask file paths from two directories.

    Entries are paired by position after sorting each directory listing by
    name, so the i-th image (in name order) is matched with the i-th mask.

    Args:
        image_dir: directory containing the input images.
        mask_dir: directory containing the ground-truth masks.

    Raises:
        ValueError: if the two directories hold a different number of entries.
    """
    # os.listdir returns entries in arbitrary order; sorting both listings
    # makes the index-based image/mask pairing deterministic.
    image_names = sorted(os.listdir(image_dir))
    mask_names = sorted(os.listdir(mask_dir))
    # Index pairing is only meaningful when both sides have the same count;
    # otherwise images and masks would silently drift out of alignment.
    if len(image_names) != len(mask_names):
        raise ValueError(
            f"image_dir and mask_dir must contain the same number of files "
            f"({len(image_names)} images vs {len(mask_names)} masks)"
        )
    self.images = [os.path.join(image_dir, name) for name in image_names]
    self.masks = [os.path.join(mask_dir, name) for name in mask_names]
def __len__(self):
    """Return the dataset size, i.e. how many image paths were collected."""
    sample_paths = self.images
    return len(sample_paths)
def num_of_samples(self):
    """Return the number of samples, counted from the collected image paths."""
    total = len(self.images)
    return total
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
image_path = self.images[idx]
mask_path = self.masks[idx]
else:
image_path = self.images[idx]
mask_path = self.masks[idx]
img = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
img = np.float32(img) / 255.0
img = np.expand_dims(img, 0)
mask[mask <= 128] = 0
mask[mask > 128] = 1
mask = np.expand_dims(mask, 0)
sample = {
'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask