Semantic Segmentation and Classification of Remote Sensing Images with TransU-Net
Project Background
Semantic segmentation and classification of remote sensing images have important applications in urban planning, disaster monitoring, and environmental protection. TransU-Net combines the Transformer and U-Net architectures, capturing both long-range dependencies and local detail, which makes it well suited to semantic segmentation of high-resolution remote sensing imagery. This project uses a TransU-Net model to accurately extract buildings from remote sensing datasets of different cities.
Project Structure
```
remote_sensing_building_segmentation/
├── data/
│ ├── train/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ ├── val/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ ├── test/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
├── models/
│ ├── transunet.py
│ ├── unet.py
│ ├── transformer.py
├── src/
│ ├── train.py
│ ├── predict.py
│ ├── utils.py
│ ├── dataset.py
├── weights/
│ └── best_model.pth
├── report/
│ ├── final_report.pdf
│ ├── figures/
│ │ ├── accuracy.png
│ │ ├── loss.png
│ │ └── …
├── requirements.txt
└── README.md
```
1. Install Dependencies
First, make sure the required libraries are installed. Create a requirements.txt file with the following contents:
```
torch
torchvision
numpy
pandas
matplotlib
tqdm
albumentations
opencv-python
tensorboard
```
Then install the dependencies with:
```bash
pip install -r requirements.txt
```
2. Dataset Preparation
Make sure your dataset is organized under data/ exactly as shown in the project structure above: parallel images/ and masks/ directories for each of the train, val, and test splits, where every mask shares the filename of its corresponding image.
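Before training, it is worth verifying that the splits are consistent; a minimal sketch, assuming the layout above:
```python
import os

# Verify that every split has matching image/mask filenames.
for split in ("train", "val", "test"):
    imgs = sorted(os.listdir(f"data/{split}/images"))
    msks = sorted(os.listdir(f"data/{split}/masks"))
    assert imgs == msks, f"{split}: image/mask filenames differ"
    print(split, len(imgs), "pairs")
```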
3. Dataset Class
Create a dataset class for loading and preprocessing the data.
3.1 src/dataset.py
```python
import os
import cv2
import torch
import numpy as np
import albumentations as A
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class RemoteSensingBuildingDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_file = self.image_files[index]
        image_path = os.path.join(self.image_dir, image_file)
        mask_path = os.path.join(self.mask_dir, image_file)  # mask shares the image's filename
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = (mask > 0).astype(np.uint8)  # binarize: 0 = background, 1 = building (assumes 0/255 masks)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        image = transforms.ToTensor()(image)         # HWC array -> CHW float tensor
        mask = torch.tensor(mask, dtype=torch.long)  # class indices for CrossEntropyLoss
        return image, mask

def get_data_loaders(data_dir, batch_size=16, num_workers=4):
    # Albumentations transforms are called with keyword arguments
    # (image=..., mask=...), matching the call in __getitem__ above.
    transform = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = RemoteSensingBuildingDataset(os.path.join(data_dir, 'train', 'images'), os.path.join(data_dir, 'train', 'masks'), transform=transform)
    val_dataset = RemoteSensingBuildingDataset(os.path.join(data_dir, 'val', 'images'), os.path.join(data_dir, 'val', 'masks'), transform=transform)
    test_dataset = RemoteSensingBuildingDataset(os.path.join(data_dir, 'test', 'images'), os.path.join(data_dir, 'test', 'masks'), transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader
```
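A quick check that the loaders produce what the training loop expects (the shapes shown are illustrative; H and W depend on your tiles):
```python
from src.dataset import get_data_loaders

# Pull one batch and confirm shapes and dtypes before training.
train_loader, val_loader, test_loader = get_data_loaders("data", batch_size=4)
images, masks = next(iter(train_loader))
print(images.shape, images.dtype)  # e.g. torch.Size([4, 3, H, W]) torch.float32
print(masks.shape, masks.dtype)    # e.g. torch.Size([4, H, W]) torch.int64
```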
4. TransU-Net Model
4.1 models/transunet.py
```python
import torch
import torch.nn as nn
from models.unet import UNet
from models.transformer import Transformer

class TransUNet(nn.Module):
    """Simplified TransU-Net: a U-Net backbone followed by a Transformer that
    refines the per-pixel logits. (The canonical TransU-Net instead applies the
    Transformer to the encoder bottleneck; this variant keeps the two-stage
    structure of this project.)"""
    def __init__(self, in_channels, num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True):
        super(TransUNet, self).__init__()
        self.unet = UNet(in_channels, num_classes)
        self.embed = nn.Linear(num_classes, embed_dim)  # per-pixel projection into the Transformer width
        self.transformer = Transformer(embed_dim, depth, num_heads, mlp_ratio, qkv_bias)
        self.head = nn.Linear(embed_dim, num_classes)   # project back to class logits

    def forward(self, x):
        x = self.unet(x)                  # (B, num_classes, H, W)
        b, c, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, num_classes): one token per pixel
        x = self.head(self.transformer(self.embed(x)))
        return x.transpose(1, 2).reshape(b, c, h, w)
```
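A minimal smoke test, assuming a deliberately small configuration: with one token per pixel, attention cost grows quadratically in H×W, so the defaults (embed_dim=768, depth=12) are only practical on small tiles:
```python
import torch
from models.transunet import TransUNet

# Small config and a 32x32 input keep the 1024-token attention cheap.
model = TransUNet(in_channels=3, num_classes=2, embed_dim=64, depth=2, num_heads=4)
out = model(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 2, 32, 32])
```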
4.2 models/unet.py
```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU blocks, the basic U-Net building block."""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """Standard U-Net. The encoder halves the resolution four times, so input
    height and width must be divisible by 16 for the skip connections to align."""
    def __init__(self, in_channels, num_classes):
        super(UNet, self).__init__()
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 512))
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(512, 1024))
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(128, 64)
        self.outc = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: save the feature map at each scale for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Decoder: upsample, concatenate the matching skip, then convolve.
        x = self.conv1(torch.cat([self.up1(x5), x4], dim=1))
        x = self.conv2(torch.cat([self.up2(x), x3], dim=1))
        x = self.conv3(torch.cat([self.up3(x), x2], dim=1))
        x = self.conv4(torch.cat([self.up4(x), x1], dim=1))
        return self.outc(x)
```
4.3 models/transformer.py
```python
import torch
import torch.nn as nn

class Transformer(nn.Module):
    """A stack of pre-norm Transformer encoder layers operating on
    (batch, sequence, embed_dim) token tensors."""
    def __init__(self, embed_dim, depth, num_heads, mlp_ratio, qkv_bias):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, mlp_ratio, qkv_bias)
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio, qkv_bias):
        super(TransformerLayer, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, bias=qkv_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        )

    def forward(self, x):
        # Self-attention: the normalized tokens serve as query, key, and value.
        y = self.norm1(x)
        x = x + self.attn(y, y, y)[0]
        x = x + self.mlp(self.norm2(x))
        return x
```
5. Training Code
5.1 src/train.py
```python
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from models.transunet import TransUNet
from src.dataset import get_data_loaders

def train_transunet(data_dir, epochs=100, batch_size=16, learning_rate=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TransUNet(in_channels=3, num_classes=2).to(device)
    train_loader, val_loader, _ = get_data_loaders(data_dir, batch_size=batch_size)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    writer = SummaryWriter()
    os.makedirs("weights", exist_ok=True)
    best_val_iou = 0.0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        running_iou = 0.0
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_iou += iou_score(outputs, masks)
        train_loss = running_loss / len(train_loader)
        train_iou = running_iou / len(train_loader)
        writer.add_scalar('Training Loss', train_loss, epoch)
        writer.add_scalar('Training IoU', train_iou, epoch)
        model.eval()
        running_val_loss = 0.0
        running_val_iou = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                running_val_loss += criterion(outputs, masks).item()
                running_val_iou += iou_score(outputs, masks)
        val_loss = running_val_loss / len(val_loader)
        val_iou = running_val_iou / len(val_loader)
        writer.add_scalar('Validation Loss', val_loss, epoch)
        writer.add_scalar('Validation IoU', val_iou, epoch)
        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Train IoU: {train_iou:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}")
        # Save a checkpoint only when the validation IoU improves.
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(model.state_dict(), "weights/best_model.pth")
    writer.close()

def iou_score(output, target):
    """Foreground IoU of the argmax prediction against the integer mask."""
    smooth = 1e-6
    output = torch.argmax(output, dim=1)
    intersection = (output & target).float().sum((1, 2))
    union = (output | target).float().sum((1, 2))
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean().item()

if __name__ == "__main__":
    data_dir = "data"
    train_transunet(data_dir)
```
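To make the metric concrete, a tiny worked example for iou_score (the tensors are made up for illustration):
```python
import torch
from src.train import iou_score

# Fake logits for a 2-class, 2x2 image: argmax gives [[0, 1], [1, 1]].
logits = torch.tensor([[[[5.0, 0.0], [0.0, 0.0]],
                        [[0.0, 3.0], [3.0, 3.0]]]])  # shape (1, 2, 2, 2)
target = torch.tensor([[[0, 1], [1, 0]]])            # shape (1, 2, 2)
# Foreground intersection = 2 pixels, union = 3 pixels -> IoU ~ 0.6667.
print(iou_score(logits, target))
```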
6. Model Evaluation
After training, evaluate the model's performance on the test set. For example:
6.1 src/predict.py
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from models.transunet import TransUNet
from src.dataset import get_data_loaders

def predict_and_plot_transunet(data_dir, model_path, num_samples=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TransUNet(in_channels=3, num_classes=2)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    _, _, test_loader = get_data_loaders(data_dir, batch_size=1)
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
    with torch.no_grad():
        for i, (images, masks) in enumerate(test_loader):
            if i >= num_samples:
                break
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            predicted = torch.argmax(outputs, dim=1)
            image = images.squeeze().cpu().numpy().transpose((1, 2, 0))
            # Undo the ImageNet normalization so the image displays correctly.
            image = np.clip(image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)
            mask = masks.squeeze().cpu().numpy()
            pred = predicted.squeeze().cpu().numpy()
            ax = axes[i] if num_samples > 1 else axes
            ax[0].imshow(image)
            ax[0].set_title("Input Image")
            ax[0].axis('off')
            ax[1].imshow(mask, cmap='gray')
            ax[1].set_title("Ground Truth Mask")
            ax[1].axis('off')
            ax[2].imshow(pred, cmap='gray')
            ax[2].set_title("Predicted Mask")
            ax[2].axis('off')
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    data_dir = "data"
    model_path = "weights/best_model.pth"
    predict_and_plot_transunet(data_dir, model_path)
```
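The script above is purely qualitative; for a single test-set number, a minimal sketch that reuses iou_score from src/train.py:
```python
import torch
from models.transunet import TransUNet
from src.dataset import get_data_loaders
from src.train import iou_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransUNet(in_channels=3, num_classes=2)
model.load_state_dict(torch.load("weights/best_model.pth", map_location=device))
model = model.to(device)
model.eval()

# Average the per-batch IoU over the entire test split.
_, _, test_loader = get_data_loaders("data", batch_size=1)
total_iou = 0.0
with torch.no_grad():
    for images, masks in test_loader:
        images, masks = images.to(device), masks.to(device)
        total_iou += iou_score(model(images), masks)
print(f"Test mean IoU: {total_iou / len(test_loader):.4f}")
```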
7. Running the Project
Make sure your dataset is in the corresponding folders.
Because the scripts import from the models and src packages, run them as modules from the project root. Start training with:
```bash
python -m src.train
```
After training, run the following to evaluate and visualize the results:
```bash
python -m src.predict
```
8. Report
8.1 Report Structure
- Abstract: briefly introduce the project background, goals, and main results.
- Introduction: describe the project background, research significance, and related work in detail.
- Dataset: introduce the source, structure, and preprocessing of the dataset.
- Method: detail the TransU-Net architecture and the training procedure.
- Experiments: describe the experimental setup, training parameters, and evaluation metrics.
- Results: present the experimental results, including loss curves, IoU scores, and visualizations.
- Discussion: analyze the results, discuss the model's strengths and weaknesses, and propose improvements.
- Conclusion: summarize the main contributions and directions for future work.
8.2 Report Generation
- Write the report in LaTeX or Markdown.
- Insert the experimental results and visualization figures into the report; see the sketch after this list for exporting figures.
- Generate a PDF and place it in the report folder.
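TensorBoard already records the curves during training; if you want standalone PNGs for report/figures/, a minimal sketch (the train_losses and val_ious values here are placeholders, in practice collected per epoch in train.py):
```python
import os
import matplotlib.pyplot as plt

def save_curve(values, ylabel, path):
    """Save a per-epoch metric curve as a PNG for the report."""
    plt.figure()
    plt.plot(range(1, len(values) + 1), values)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()

os.makedirs("report/figures", exist_ok=True)
# Placeholder values; append the real per-epoch numbers in train.py.
train_losses = [0.65, 0.42, 0.31, 0.27]
val_ious = [0.55, 0.63, 0.70, 0.72]
save_curve(train_losses, "Loss", "report/figures/loss.png")
save_curve(val_ious, "IoU", "report/figures/accuracy.png")
```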
9. Component Overview
- Dataset class: RemoteSensingBuildingDataset loads and preprocesses the data.
- Data loaders: get_data_loaders builds the training, validation, and test loaders.
- TransU-Net model: transunet.py defines the TransU-Net model.
- Training script: train.py trains the TransU-Net model.
- Prediction script: predict.py evaluates the trained model and visualizes the input image, ground-truth mask, and predicted mask.