如何使用PyTorch框架和U-Net架构来训练火车钢轨缺陷分割数据集,检测火车铁路轨道表面的腐蚀孔洞等缺陷区域分割任务
文章目录
火车钢轨缺陷分割数据集:
用于检测火车轨道表面的腐蚀孔洞等缺陷区域分割任务。
该模型用于检测火车轨道表面腐蚀孔洞等缺陷区域的分割任务,同样适用于其他类型的语义分割任务,包括钢轨缺陷分割。
如何使用PyTorch框架和U-Net架构来训练这个数据集的详细步骤和代码示例。
文章及代码仅供参考。
1. 环境设置
首先确保你的环境中安装了必要的库:
pip install torch torchvision numpy matplotlib opencv-python
2. 数据准备
假设你的数据集已经分为训练集和验证集,并且每张图片都有对应的标注(即每个像素点属于背景或缺陷)。通常情况下,标注图像是灰度图像,其中0表示背景,1表示缺陷。
创建一个Dataset
类来加载这些数据:
dataset.py
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class RailDefectDataset(Dataset):
    """Rail-surface defect segmentation dataset.

    Pairs each image in ``img_dir`` with a mask in ``mask_dir`` whose file
    name is the image name with ``.jpg`` replaced by ``_mask.png``.
    Masks are returned as float32 arrays with defect pixels == 1.0.

    Fix vs. the original: ``os.listdir`` returns files in arbitrary order
    and may include non-image files; the listing is now sorted (for a
    deterministic sample order) and filtered to image extensions.
    """

    # Extensions accepted when scanning img_dir.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.images = sorted(
            f for f in os.listdir(img_dir)
            if f.lower().endswith(self.IMG_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        # Adjust this name mapping if your mask files are named differently.
        mask_path = os.path.join(
            self.mask_dir, self.images[idx].replace('.jpg', '_mask.png'))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0  # binarize: 255 -> 1.0 (adjust to your masks)
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask
3. U-Net模型定义
接下来,我们需要定义U-Net模型。
unet_model.py
import torch
import torch.nn as nn
class UNet(nn.Module):
    """U-Net for binary segmentation; output is sigmoid probabilities in [0, 1]
    with the same spatial size as the input.

    Fix vs. the original: the original added a fifth decoder stage
    (``upconv1``/``decoder1``) with no matching encoder level — it upsampled
    past the input resolution and fed 32 channels into a block expecting
    64 + 32 = 96, so the forward pass crashed with a channel mismatch.
    The decoder now mirrors the 4-level encoder exactly.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        def conv_block(in_channels, out_channels):
            # Two 3x3 convolutions; padding=1 preserves spatial size.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Each upconv doubles resolution and halves channels; after the skip
        # concatenation the channel count matches the decoder block's input.
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = conv_block(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = conv_block(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = conv_block(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)                # H,    64 ch
        enc2 = self.encoder2(self.pool(enc1))  # H/2,  128 ch
        enc3 = self.encoder3(self.pool(enc2))  # H/4,  256 ch
        enc4 = self.encoder4(self.pool(enc3))  # H/8,  512 ch
        dec4 = self.upconv4(enc4)
        dec4 = self.decoder4(torch.cat((dec4, enc3), dim=1))
        dec3 = self.upconv3(dec4)
        dec3 = self.decoder3(torch.cat((dec3, enc2), dim=1))
        dec2 = self.upconv2(dec3)
        dec2 = self.decoder2(torch.cat((dec2, enc1), dim=1))
        return torch.sigmoid(self.final_conv(dec2))
4. 训练脚本
编写训练脚本来训练U-Net模型。
train_unet.py
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn
from unet_model import UNet
from dataset import RailDefectDataset
from torchvision import transforms
def train_unet():
    """Train the U-Net on the rail-defect dataset, printing per-epoch
    training and validation loss.

    Fix vs. the original: the validation loss was reported for the last
    batch only; it is now averaged over the whole validation set.
    Edit the ``path/to/...`` directories to point at your data.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.ToTensor(),
        # ImageNet statistics — a reasonable generic default.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    mask_transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = RailDefectDataset(img_dir='path/to/train/images/',
                                      mask_dir='path/to/train/masks/',
                                      transform=transform,
                                      mask_transform=mask_transform)
    val_dataset = RailDefectDataset(img_dir='path/to/val/images/',
                                    mask_dir='path/to/val/masks/',
                                    transform=transform,
                                    mask_transform=mask_transform)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

    model = UNet(in_channels=3, out_channels=1).to(device)
    # The model ends in a sigmoid, so plain BCELoss is the matching criterion.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    num_epochs = 20
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Loss: {running_loss/len(train_loader.dataset)}")

        # Validation: accumulate over every batch and report the average,
        # instead of printing only the final batch's loss.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)
        print(f"Validation Loss: {val_loss/len(val_loader.dataset)}")


if __name__ == "__main__":
    train_unet()
下面给出另一版基于U-Net的完整实现,涵盖数据准备、模型定义、训练、评估与结果可视化,特别针对火车轨道表面腐蚀孔洞等缺陷区域的分割任务。
1. 环境设置
首先,确保你的环境中安装了必要的库:
pip install torch torchvision numpy matplotlib opencv-python
2. 数据准备
假设你的数据集已经分为训练集和验证集,并且每张图片都有对应的标注(即每个像素点属于背景或缺陷)。通常情况下,标注图像是灰度图像,其中0表示背景,1表示缺陷。
创建一个Dataset
类来加载这些数据:
dataset.py
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class RailDefectSegmentationDataset(Dataset):
    """Rail-surface defect dataset yielding (image, mask) pairs, optionally
    converted by ``transform`` / ``mask_transform``.

    Fix vs. the original: the image and mask path lists were built from two
    separate ``os.listdir`` calls, whose ordering is arbitrary and not
    guaranteed to agree, so images could be paired with the wrong masks.
    A single sorted listing now derives both lists deterministically.
    """

    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        names = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))
        self.images = [os.path.join(img_dir, f) for f in names]
        self.masks = [os.path.join(mask_dir, f.replace('.jpg', '_mask.png'))
                      for f in names]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        mask = Image.open(self.masks[idx]).convert("L")  # grayscale mask
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask
3. U-Net模型定义
接下来,我们定义用于语义分割的U-Net模型。
unet_model.py
import torch
import torch.nn as nn
class UNet(nn.Module):
    """Four-level U-Net assembled from DoubleConv/Down/Up/OutConv blocks.

    ``forward`` returns raw per-pixel logits (no activation applied), so
    it should be paired with a logits-based loss such as BCEWithLogitsLoss.
    """

    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level's output for the skip connections.
        skip0 = self.inc(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        bottom = self.down4(skip3)
        # Decoder: each Up fuses the upsampled feature with its skip.
        out = self.up1(bottom, skip3)
        out = self.up2(out, skip2)
        out = self.up3(out, skip1)
        out = self.up4(out, skip0)
        return self.outc(out)
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Spatial size is preserved (padding=1); the channel count changes from
    ``in_channels`` to ``out_channels`` at the first convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(2):
            layers.append(nn.Conv2d(channels, out_channels,
                                    kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            channels = out_channels
        # Same layer order (and hence state_dict keys) as listing them inline.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """One encoder step: halve the spatial resolution with 2x2 max-pooling,
    then change the channel count with a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stage = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )
        self.maxpool_conv = stage

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """One decoder step: upscale ``x1`` by 2x, concatenate the encoder skip
    ``x2``, and fuse the result with a DoubleConv.

    Fix vs. the original: bilinear ``nn.Upsample`` does not change the
    channel count, so ``Up(1024, 512)`` concatenated 1024 + 512 = 1536
    channels into a ``DoubleConv(1024, 512)`` and crashed at runtime.
    A transposed convolution now halves the channels before the skip
    concatenation, so the DoubleConv input is exactly ``in_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # in_channels -> in_channels // 2 channels, spatial size doubled.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
                                     kernel_size=2, stride=2)
        # After cat: in_channels // 2 (upsampled) + in_channels // 2 (skip).
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # NOTE(review): assumes x2 spatially matches the upsampled x1 —
        # true when the network input sides are divisible by 16.
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
4. 训练脚本
编写训练脚本来训练U-Net模型。
train_unet.py
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn
from unet_model import UNet
from dataset import RailDefectSegmentationDataset
from torchvision import transforms
import os
def train_unet(dataset_dir, epochs=20, batch_size=4, lr=1e-4):
    """Train the U-Net on ``dataset_dir`` and save weights to ``unet.pth``.

    Expects ``train/images``, ``train/masks``, ``val/images``, ``val/masks``
    under ``dataset_dir``.

    Fixes vs. the original:
    1. ``outputs.squeeze(1)`` produced (B, H, W) while the ToTensor masks
       are (B, 1, H, W); BCEWithLogitsLoss requires matching shapes and
       raised a size-mismatch error — both sides now keep (B, 1, H, W).
    2. Validation loss is averaged over the whole set instead of printing
       only the last batch's value.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    mask_transform = transforms.ToTensor()
    train_dataset = RailDefectSegmentationDataset(
        os.path.join(dataset_dir, 'train/images/'),
        os.path.join(dataset_dir, 'train/masks/'),
        transform=transform, mask_transform=mask_transform)
    val_dataset = RailDefectSegmentationDataset(
        os.path.join(dataset_dir, 'val/images/'),
        os.path.join(dataset_dir, 'val/masks/'),
        transform=transform, mask_transform=mask_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = UNet(n_channels=3, n_classes=1).to(device)
    # The model outputs raw logits, so BCEWithLogitsLoss is the right choice.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # Both outputs and masks are (B, 1, H, W) — no squeeze needed.
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        print(f"Epoch [{epoch+1}/{epochs}], "
              f"Loss: {running_loss/len(train_loader.dataset)}")

        # Validation: report the average loss over the full set.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)
        print(f"Validation Loss: {val_loss/len(val_loader.dataset)}")

    torch.save(model.state_dict(), os.path.join(dataset_dir, 'unet.pth'))


if __name__ == "__main__":
    dataset_dir = 'path/to/dataset/'  # replace with your dataset path
    train_unet(dataset_dir)
5. 模型评估与预测结果可视化
在训练完成后,你可以通过以下脚本对模型进行评估并可视化预测结果。
evaluate_and_visualize.py
import torch
from unet_model import UNet
import cv2
import matplotlib.pyplot as plt
from torchvision import transforms
def evaluate_and_visualize(model_path, image_path):
    """Run the trained U-Net on one image and plot it beside the
    predicted mask.

    Fixes vs. the original:
    1. ``np.uint8`` was used but numpy was never imported in this script,
       causing a NameError — imported locally here.
    2. The checkpoint is loaded with ``map_location=device`` so weights
       trained on CUDA also load on a CPU-only machine.
    """
    import numpy as np  # was missing from this script's imports

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet(n_channels=3, n_classes=1).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_shape = image.shape[:2]
    tensor_image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(tensor_image).squeeze().cpu().numpy()
    # The model outputs logits: logit > 0 <=> sigmoid probability > 0.5.
    output = (output > 0).astype(np.uint8) * 255
    output_resized = cv2.resize(output, (original_shape[1], original_shape[0]))
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.title('Original Image')
    plt.imshow(image)
    plt.subplot(1, 2, 2)
    plt.title('Predicted Mask')
    plt.imshow(output_resized, cmap='gray')
    plt.show()


if __name__ == "__main__":
    model_path = 'path/to/unet.pth'  # path to the trained model weights
    image_path = 'path/to/test/image.jpg'  # path to a test image
    evaluate_and_visualize(model_path, image_path)
以上代码提供了一个完整的从数据准备到模型训练、评估再到预测结果可视化的流程。