基于U-Net的裂缝检测实现

基于U-Net的裂缝检测实现

本文介绍一个基于U-Net网络架构的裂缝检测项目(为了提高精度,添加了注意力机制),该项目可以对道路、建筑物等结构中的裂缝进行精确检测和分割,非常适合深度学习与计算机视觉领域的入门学习者。

项目GitHub地址:https://github.com/Ennan010/crack-detection-unet

项目概述

该项目使用改进型U-Net网络结构,融合了注意力机制,能够有效识别和分割图像中的裂缝区域。模型在CRACK500数据集上训练,可应用于道路、墙壁、桥梁等基础设施的裂缝检测。

在这里插入图片描述

在这里插入图片描述

环境配置

项目基于Python实现,依赖以下关键包:

scikit-learn>=0.23.2 
matplotlib==3.7.5
numpy==1.24.3
Pillow==10.3.0
torch>=2.0.0
torchvision>=0.15.0
tqdm==4.66.4

部署流程

1. 获取代码

git https://github.com/Ennan010/crack-detection-unet.git
cd crack_unet

2. 环境配置

创建并激活虚拟环境(可选):

# 使用conda
conda create -n crack_unet python=3.10
conda activate crack_unet

安装依赖:

pip install -r requirements.txt

3. 模型训练

本项目已经提供了预训练模型 best_model.pth,您可以直接使用。

https://pan.baidu.com/s/13PIct-k2FAQyei8ZlNnq-g?pwd=o675

如需重新训练模型或基于自定义数据集进行训练,执行以下命令:

python train.py

默认参数已对CRACK500数据集进行了优化,如需调整学习率、批次大小等参数,可直接修改train.py文件中的相关配置。

4. 模型推理

项目支持单张图像推理和批量推理两种模式:

单张图像推理

测试图像已预先存放在test_pic文件夹中,可直接用于模型测试:

python predict.py --image test_pic/road.jpg --model output_results/best_model.pth --output results

参数说明:

  • --image: 输入图像路径
  • --model: 模型权重文件路径(默认:output_results/best_model.pth)
  • --output: 输出目录(默认:results)
  • --threshold: 裂缝分割阈值,范围0-1(默认:0.5)
  • --no-postprocessing: 禁用后处理(默认启用)

输出解析

推理完成后,结果保存在results目录中:

  1. *_prediction.png: 包含原始图像、概率热力图和叠加显示的三联图
  2. *_mask.png: 二值化裂缝掩码图像

核心代码展示

下面展示项目中的核心代码文件,虽然代码中已经添加了详细注释,这里还是将完整代码放出来供大家参考:

模型结构 (unet_model.py)

import torch
import torch.nn as nn
from dataset import AttentionGate

"""
U-Net模型实现,增加了注意力机制以提高裂缝检测精度。
本模型是标准U-Net的改进版本,通过在解码器部分添加注意力门控模块,
使网络能够更好地关注裂缝相关特征,抑制背景干扰。
"""

class DoubleConv(nn.Module):
    """
    双重卷积模块:U-Net的基本构建块
    包含两个连续的3×3卷积层,每个卷积后接BatchNorm和ReLU激活
    这种组合可以增强特征提取能力并加速训练过程
    
    参数:
        in_channels: 输入通道数
        out_channels: 输出通道数
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    """
    改进型U-Net网络,用于裂缝检测
    
    主要改进:
    1. 在每个解码器层添加注意力门控(AttentionGate)
    2. 使用BatchNorm加速训练并提高稳定性
    
    参数:
        in_channels: 输入图像通道数,默认为3(RGB图像)
        out_channels: 输出通道数,默认为1(二元分割掩码)
    """
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        
        # 编码器路径:连续的下采样过程,每层特征通道数翻倍
        self.enc1 = DoubleConv(in_channels, 64)  # 第一层编码器
        self.enc2 = DoubleConv(64, 128)          # 第二层编码器
        self.enc3 = DoubleConv(128, 256)         # 第三层编码器
        self.enc4 = DoubleConv(256, 512)         # 第四层编码器
        
        # 网络最深层的瓶颈部分,具有最大的特征通道数
        self.bottleneck = DoubleConv(512, 1024)
        
        # 解码器路径:结合转置卷积上采样和注意力机制
        # 每层包括:
        # 1. 转置卷积上采样
        # 2. 注意力门控,增强相关特征
        # 3. 特征拼接
        # 4. 双重卷积处理
        
        # 第一层解码器(最深层)
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)  # 上采样
        self.att4 = AttentionGate(F_g=512, F_l=512, F_int=256)             # 注意力门控
        self.dec4 = DoubleConv(1024, 512)                                  # 处理拼接后的特征
        
        # 第二层解码器
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.att3 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.dec3 = DoubleConv(512, 256)
        
        # 第三层解码器
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.att2 = AttentionGate(F_g=128, F_l=128, F_int=64)
        self.dec2 = DoubleConv(256, 128)
        
        # 第四层解码器(最浅层)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.att1 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.dec1 = DoubleConv(128, 64)
        
        # 最终输出层:1x1卷积将特征图映射为所需的分割掩码
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        
        # 下采样操作:使用最大池化
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """
        前向传播过程:
        1. 编码器路径提取多尺度特征
        2. 瓶颈层捕获全局上下文
        3. 解码器路径结合注意力机制恢复空间细节
        """
        # 编码路径: 提取多尺度特征
        enc1 = self.enc1(x)                    # 最高分辨率特征
        enc2 = self.enc2(self.pool(enc1))      # 第一次下采样
        enc3 = self.enc3(self.pool(enc2))      # 第二次下采样
        enc4 = self.enc4(self.pool(enc3))      # 第三次下采样
        
        # 瓶颈: 最低分辨率,最高通道数
        bottleneck = self.bottleneck(self.pool(enc4))
        
        # 解码路径: 结合注意力机制的特征融合
        # 处理最深层特征
        dec4 = self.up4(bottleneck)            # 上采样瓶颈特征
        enc4_att = self.att4(dec4, enc4)       # 注意力处理编码器特征
        dec4 = torch.cat((dec4, enc4_att), dim=1)  # 拼接特征
        dec4 = self.dec4(dec4)                 # 处理拼接后的特征
        
        # 处理第三层特征
        dec3 = self.up3(dec4)
        enc3_att = self.att3(dec3, enc3)
        dec3 = torch.cat((dec3, enc3_att), dim=1)
        dec3 = self.dec3(dec3)
        
        # 处理第二层特征
        dec2 = self.up2(dec3)
        enc2_att = self.att2(dec2, enc2)
        dec2 = torch.cat((dec2, enc2_att), dim=1)
        dec2 = self.dec2(dec2)
        
        # 处理第一层特征(最浅层)
        dec1 = self.up1(dec2)
        enc1_att = self.att1(dec1, enc1)
        dec1 = torch.cat((dec1, enc1_att), dim=1)
        dec1 = self.dec1(dec1)
        
        # 最终输出: 生成裂缝分割掩码
        # 注意:需要在外部使用sigmoid函数将输出转换为概率
        return self.final_conv(dec1)

数据处理与注意力模块 (dataset.py)

import os
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageEnhance
import numpy as np
import random
import torch.nn as nn

"""
数据处理模块: 实现裂缝数据集的加载、预处理和增强
以及注意力门控模块的定义
"""

class CrackDataset(Dataset):
    """
    裂缝数据集加载类
    
    功能:
    1. 加载图像和对应的掩码
    2. 应用数据增强以提高模型泛化能力
    3. 提供统一的数据格式和预处理
    
    参数:
        image_dir: 原始图像所在目录
        mask_dir: 掩码图像所在目录
        transform: 图像变换(通常用于调整大小和转为张量)
        augment: 是否启用数据增强
    """
    def __init__(self, image_dir, mask_dir, transform=None, augment=False):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.augment = augment
        
        # 检查目录是否存在
        if not os.path.exists(image_dir):
            raise FileNotFoundError(f"图像目录不存在: {image_dir}")
        if not os.path.exists(mask_dir):
            raise FileNotFoundError(f"掩码目录不存在: {mask_dir}")
            
        # 获取所有图像文件名,只选择图像文件
        self.images = sorted([f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.png', '.jpeg'))])
        
        # 验证数据完整性:过滤掉没有对应掩码的图像
        valid_images = []
        for img in self.images:
            base_name = os.path.splitext(img)[0]
            mask_path = os.path.join(mask_dir, base_name + "_mask.png")
            if os.path.exists(mask_path):
                valid_images.append(img)
            
        self.images = valid_images
        print(f"找到 {len(self.images)} 个有效的图像-掩码对")
        
        # 打印前2个样本,帮助调试
        for i in range(min(2, len(self.images))):
            img_name = self.images[i]
            base_name = os.path.splitext(img_name)[0]
            mask_name = base_name + "_mask.png"
            print(f"图像 {i}: {img_name}")
            print(f"掩码 {i}: {mask_name}")
        
    def __len__(self):
        """返回数据集大小"""
        return len(self.images)
    
    def __getitem__(self, idx):
        """
        获取单个样本(图像和对应掩码),并应用数据增强
        
        注意: 
        - 图像和掩码必须进行相同的几何变换以保持对齐
        - 只对图像进行亮度和对比度调整
        - 处理异常情况,确保训练过程不中断
        """
        try:
            # 加载图像
            img_name = self.images[idx]
            img_path = os.path.join(self.image_dir, img_name)
            
            # 构建掩码路径 - 使用统一的命名规则 {base_name}_mask.png
            base_name = os.path.splitext(img_name)[0]
            mask_name = base_name + "_mask.png"
            mask_path = os.path.join(self.mask_dir, mask_name)
            
            # 读取图像和掩码,并转换为合适的格式
            image = Image.open(img_path).convert('RGB')  # 确保3通道
            mask = Image.open(mask_path).convert('L')    # 单通道灰度
            
            # 数据增强 (只在训练时随机应用)
            if self.augment and random.random() > 0.5:
                # 1. 随机旋转 - 裂缝可能以任何角度出现
                angle = random.choice([90, 180, 270])
                image = image.rotate(angle)
                mask = mask.rotate(angle)  # 掩码也需要相同旋转
                
                # 2. 随机水平翻转
                if random.random() > 0.5:
                    image = image.transpose(Image.FLIP_LEFT_RIGHT)
                    mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
                    
                # 3. 随机亮度、对比度调整(仅应用于图像,不应用于掩码)
                enhancer = ImageEnhance.Brightness(image)
                image = enhancer.enhance(random.uniform(0.8, 1.2))  # 亮度变化范围±20%
                
                enhancer = ImageEnhance.Contrast(image)
                image = enhancer.enhance(random.uniform(0.8, 1.2))  # 对比度变化范围±20%
            
            # 应用其他变换(尺寸调整、转为张量等)
            if self.transform:
                image = self.transform(image)
                mask = self.transform(mask)
            
            # 确保掩码是二值的(0或1)
            mask = (mask > 0.5).float()  # 大于0.5的像素视为裂缝(1),否则为背景(0)
            
            return image, mask
            
        except Exception as e:
            # 错误处理:防止单个样本错误导致整个训练停止
            print(f"处理索引 {idx} 的图像时出错: {e}")
            # 特殊情况:如果第一个样本就错误,创建零张量
            if idx == 0:
                image = torch.zeros((3, 256, 256))  # 创建空的RGB图像
                mask = torch.zeros((1, 256, 256))   # 创建空的掩码
                return image, mask
            # 否则尝试返回第一个样本
            return self.__getitem__(0)

class AttentionGate(nn.Module):
    """
    注意力门控模块
    
    这是本项目的核心改进部分,使网络能够关注裂缝区域并抑制背景噪声。
    工作原理:
    1. 分别处理来自上采样路径(g)和跳跃连接(x)的特征
    2. 计算特征间的相关性,生成注意力权重
    3. 用权重调整跳跃连接特征,突出重要区域
    
    参数:
        F_g: 上采样特征的通道数
        F_l: 跳跃连接特征的通道数
        F_int: 中间特征的通道数(降维用)
    """
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        # 处理上采样特征的卷积层
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),  # 1x1卷积降维
            nn.BatchNorm2d(F_int)                  # 批标准化提高稳定性
        )
        # 处理跳跃连接特征的卷积层
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),  # 1x1卷积降维
            nn.BatchNorm2d(F_int)                  # 批标准化
        )
        # 生成注意力图的卷积层
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),    # 输出单通道注意力图
            nn.BatchNorm2d(1),                     # 批标准化
            nn.Sigmoid()                           # 将值限制在0-1范围内
        )
        # 激活函数
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, g, x):
        """
        前向传播计算注意力权重并应用于特征
        
        参数:
            g: 上采样得到的特征(来自解码器)
            x: 跳跃连接特征(来自编码器)
            
        返回:
            x * psi: 加权后的跳跃连接特征
        """
        # 降维处理
        g1 = self.W_g(g)      # 处理上采样特征
        x1 = self.W_x(x)      # 处理跳跃连接特征
        
        # 特征融合和激活
        psi = self.relu(g1 + x1)  # 特征加和后ReLU激活
        
        # 生成注意力系数(0-1范围)
        psi = self.psi(psi)   # 生成注意力图
        
        # 注意力加权:相当于软掩码,突出重要区域
        return x * psi        # 将注意力应用到原始特征

网络架构改进与挑战

经过实验验证,标准U-Net网络结构在裂缝检测任务上表现较为一般,特别是对于细小且不规则的裂缝结构。为提升检测精度,本项目引入了多项改进策略:

  1. 注意力机制增强:在U-Net解码器部分的每一层引入注意力门控模块(Attention Gate),这些模块均基于开源实现,能够自适应地聚焦于裂缝特征,同时抑制背景噪声干扰
  2. 特征提取优化:增强上下文信息与局部特征的融合能力,提高对细微裂缝的识别精度

尽管这些改进显著提升了模型性能,但在复杂场景下(如纹理丰富的背景、光照不均匀区域或极细微裂缝)仍存在一定的检测挑战。当前版本在CRACK500标准数据集上取得了较为理想的性能,但在实际应用中可能需要针对特定场景进行进一步优化调整。

项目总结

本项目展示了如何运用深度学习技术解决结构裂缝检测问题。通过引入注意力机制对标准U-Net进行改进,显著提升了模型对细小裂缝的检测能力。虽然作为入门级实现,本方案在某些复杂场景下仍有优化空间,但其思路和实现方式为解决类似计算机视觉问题提供了良好的起点和参考。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值