如何使用PyTorch框架和U-Net架构来训练火车钢轨缺陷分割数据集,检测火车铁路轨道表面的腐蚀孔洞等缺陷区域分割任务

如何使用PyTorch框架和U-Net架构来训练火车钢轨缺陷分割数据集,检测火车铁路轨道表面的腐蚀孔洞等缺陷区域分割任务

火车钢轨缺陷分割数据集:在这里插入图片描述

用于检测火车轨道表面的腐蚀孔洞等缺陷区域分割任务。在这里插入图片描述
在这里插入图片描述
用于检测火车轨道表面腐蚀孔洞等缺陷区域分割任务的模型,适用于其他类型的分割任务,包括你提到的钢轨缺陷分割。

如何使用PyTorch框架和U-Net架构来训练这个数据集的详细步骤和代码示例。
文章及代码仅供参考。

1. 环境设置

首先确保你的环境中安装了必要的库:

pip install torch torchvision numpy matplotlib opencv-python

2. 数据准备

假设你的数据集已经分为训练集和验证集,并且每张图片都有对应的标注(即每个像素点属于背景或缺陷)。通常情况下,标注图像是灰度图像,其中0表示背景,1表示缺陷。

创建一个Dataset类来加载这些数据:

dataset.py
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class RailDefectDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.images = os.listdir(img_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '_mask.png'))  # 根据实际情况调整文件名格式
        
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0  # 如果你的掩码图像是二值化后的,请根据实际情况调整
        
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        
        return image, mask

3. U-Net模型定义

接下来,我们需要定义U-Net模型。

unet_model.py
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        
        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )
        
        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = conv_block(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = conv_block(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = conv_block(128, 64)
        self.upconv1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = conv_block(64 + 32, 32)
        self.final_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))

        dec4 = self.upconv4(enc4)
        dec4 = torch.cat((dec4, enc3), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc2), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc1), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.final_conv(dec1))

4. 训练脚本

编写训练脚本来训练U-Net模型。

train_unet.py
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn
from unet_model import UNet
from dataset import RailDefectDataset
from torchvision import transforms

def train_unet():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    mask_transform = transforms.Compose([transforms.ToTensor()])
    
    train_dataset = RailDefectDataset(img_dir='path/to/train/images/', mask_dir='path/to/train/masks/', 
                                      transform=transform, mask_transform=mask_transform)
    val_dataset = RailDefectDataset(img_dir='path/to/val/images/', mask_dir='path/to/val/masks/',
                                    transform=transform, mask_transform=mask_transform)
    
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    
    model = UNet(in_channels=3, out_channels=1).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    
    num_epochs = 20
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader.dataset)}")
        
        # Validation phase (optional)
        model.eval()
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                
                outputs = model(images)
                loss = criterion(outputs, masks)
                print(f"Validation Loss: {loss.item()}")

if __name__ == "__main__":
    train_unet()

我们需要使用适合语义分割的模型,如U-Net包括如何准备数据、定义模型、训练以及评估和可视化结果,特别针对火车轨道表面腐蚀孔洞等缺陷区域的分割任务。

1. 环境设置

首先,确保你的环境中安装了必要的库:

pip install torch torchvision numpy matplotlib opencv-python

2. 数据准备

假设你的数据集已经分为训练集和验证集,并且每张图片都有对应的标注(即每个像素点属于背景或缺陷)。通常情况下,标注图像是灰度图像,其中0表示背景,1表示缺陷。

创建一个Dataset类来加载这些数据:

dataset.py
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class RailDefectSegmentationDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.images = [os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith('.jpg')]
        self.masks = [os.path.join(mask_dir, f.replace('.jpg', '_mask.png')) for f in os.listdir(img_dir) if f.endswith('.jpg')]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        mask_path = self.masks[idx]
        
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # 假设掩码是灰度图像
        
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        
        return image, mask

3. U-Net模型定义

接下来,我们定义用于语义分割的U-Net模型。

unet_model.py
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

4. 训练脚本

编写训练脚本来训练U-Net模型。

train_unet.py
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn
from unet_model import UNet
from dataset import RailDefectSegmentationDataset
from torchvision import transforms
import os

def train_unet(dataset_dir, epochs=20, batch_size=4, lr=1e-4):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    mask_transform = transforms.ToTensor()

    train_dataset = RailDefectSegmentationDataset(os.path.join(dataset_dir, 'train/images/'),
                                                  os.path.join(dataset_dir, 'train/masks/'),
                                                  transform=transform, mask_transform=mask_transform)
    val_dataset = RailDefectSegmentationDataset(os.path.join(dataset_dir, 'val/images/'),
                                                os.path.join(dataset_dir, 'val/masks/'),
                                                transform=transform, mask_transform=mask_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = UNet(n_channels=3, n_classes=1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.squeeze(1), masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader.dataset)}")

        # Validation phase (optional)
        model.eval()
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)
                loss = criterion(outputs.squeeze(1), masks)
                print(f"Validation Loss: {loss.item()}")

    torch.save(model.state_dict(), os.path.join(dataset_dir, 'unet.pth'))

if __name__ == "__main__":
    dataset_dir = 'path/to/dataset/'  # 替换为你的数据集路径
    train_unet(dataset_dir)

5. 模型评估与预测结果可视化

在训练完成后,你可以通过以下脚本对模型进行评估并可视化预测结果。

evaluate_and_visualize.py
import torch
from unet_model import UNet
import cv2
import matplotlib.pyplot as plt
from torchvision import transforms

def evaluate_and_visualize(model_path, image_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet(n_channels=3, n_classes=1).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_shape = image.shape[:2]
    tensor_image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(tensor_image).squeeze().cpu().numpy()
        output = (output > 0).astype(np.uint8) * 255  # 将输出二值化

    output_resized = cv2.resize(output, (original_shape[1], original_shape[0]))

    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.title('Original Image')
    plt.imshow(image)
    plt.subplot(1, 2, 2)
    plt.title('Predicted Mask')
    plt.imshow(output_resized, cmap='gray')
    plt.show()

if __name__ == "__main__":
    model_path = 'path/to/unet.pth'  # 训练好的模型权重路径
    image_path = 'path/to/test/image.jpg'  # 测试图像路径
    evaluate_and_visualize(model_path, image_path)

以上代码提供了一个完整的从数据准备到模型训练、评估再到预测结果可视化的流程

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值