Semantic Segmentation and Classification of Remote Sensing Images with TransU-Net: How to Run and Train

Project Background
Semantic segmentation and classification of remote sensing images have important applications in urban planning, disaster monitoring, and environmental protection. TransU-Net combines the Transformer and U-Net architectures, capturing both long-range dependencies and local detail, which makes it well suited to semantic segmentation of high-resolution remote sensing imagery. This project uses a TransU-Net model to accurately extract buildings from remote sensing datasets of different cities.
Project Structure
remote_sensing_building_segmentation/
├── data/
│ ├── train/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ ├── val/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ ├── test/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
├── models/
│ ├── transunet.py
│ ├── unet.py
│ ├── transformer.py
├── src/
│ ├── train.py
│ ├── predict.py
│ ├── utils.py
│ ├── dataset.py
├── weights/
│ └── best_model.pth
├── report/
│ ├── final_report.pdf
│ ├── figures/
│ │ ├── accuracy.png
│ │ ├── loss.png
│ │ └── …
├── requirements.txt
└── README.md

1. Install Dependencies
First, make sure the required libraries are installed. Create a requirements.txt file with the following content (opencv-python and tensorboard are included because the scripts below import cv2 and SummaryWriter):

torch
torchvision
numpy
pandas
matplotlib
tqdm
albumentations
opencv-python
tensorboard
Then install the dependencies with:

bash
pip install -r requirements.txt
2. Dataset Preparation
Make sure your dataset is organized in the following structure, with each mask sharing its filename with the corresponding image:

remote_sensing_building_segmentation/
├── data/
│ ├── train/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ ├── val/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ ├── test/
│ │ ├── images/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
│ │ ├── masks/
│ │ │ ├── 000001.png
│ │ │ ├── 000002.png
│ │ │ └── …
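Before training, it is worth verifying that every image has a matching mask. A minimal sanity-check sketch (a hypothetical helper, not part of the project files; run it from the project root):
python
import os

def check_split(split, root="data"):
    # Verify that each image in a split has a mask with the same filename.
    images = set(os.listdir(os.path.join(root, split, "images")))
    masks = set(os.listdir(os.path.join(root, split, "masks")))
    missing = images - masks
    if missing:
        print(f"{split}: {len(missing)} images without masks, e.g. {sorted(missing)[:3]}")
    else:
        print(f"{split}: {len(images)} image/mask pairs OK")

for split in ("train", "val", "test"):
    check_split(split)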
3. Dataset Class
Create a dataset class to load and preprocess the data.

3.1 src/dataset.py
python
import os
import cv2
import torch
import numpy as np
import albumentations as A
from torch.utils.data import Dataset
from torchvision import transforms

class RemoteSensingBuildingDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_files = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_file = self.image_files[index]
        image_path = os.path.join(self.image_dir, image_file)
        mask_path = os.path.join(self.mask_dir, image_file)

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # Assumes binary masks stored as 0/255; map them to class indices {0, 1}.
        mask = (mask > 0).astype(np.uint8)

        if self.transform:
            # The transform must follow the albumentations interface,
            # i.e. transform(image=..., mask=...).
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        image = transforms.ToTensor()(image)
        mask = torch.tensor(mask, dtype=torch.long)

        return image, mask

def get_data_loaders(image_dir, mask_dir, batch_size=16, num_workers=4):
    # Use albumentations here (not torchvision.transforms), because the dataset
    # calls transform(image=..., mask=...). The Resize makes differently sized
    # tiles batchable; drop it if all tiles already share one size.
    transform = A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    train_dataset = RemoteSensingBuildingDataset(
        os.path.join(image_dir, 'train', 'images'),
        os.path.join(mask_dir, 'train', 'masks'), transform=transform)
    val_dataset = RemoteSensingBuildingDataset(
        os.path.join(image_dir, 'val', 'images'),
        os.path.join(mask_dir, 'val', 'masks'), transform=transform)
    test_dataset = RemoteSensingBuildingDataset(
        os.path.join(image_dir, 'test', 'images'),
        os.path.join(mask_dir, 'test', 'masks'), transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader
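
A quick usage check of the loaders (a hypothetical snippet, assuming the data/ layout above; run from the project root):
python
from src.dataset import get_data_loaders

train_loader, val_loader, test_loader = get_data_loaders("data", "data", batch_size=4)
images, masks = next(iter(train_loader))
print(images.shape, images.dtype)  # e.g. torch.Size([4, 3, 256, 256]) torch.float32
print(masks.shape, masks.dtype)    # e.g. torch.Size([4, 256, 256]) torch.int64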
4. TransU-Net Model
4.1 models/transunet.py
python
import torch
import torch.nn as nn
from models.unet import UNet
from models.transformer import Transformer

class TransUNet(nn.Module):
    def __init__(self, in_channels, num_classes, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, qkv_bias=True):
        super().__init__()
        self.unet = UNet(in_channels, num_classes)
        self.transformer = Transformer(embed_dim, depth, num_heads, mlp_ratio, qkv_bias)
        # Project per-pixel logits to/from the transformer width so the two
        # modules can be chained. This is a simplification: the original
        # TransU-Net places the transformer at the encoder bottleneck, and
        # full-resolution attention as done here is memory-hungry.
        self.proj_in = nn.Linear(num_classes, embed_dim)
        self.proj_out = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.unet(x)                        # (B, C, H, W) per-pixel logits
        b, c, h, w = x.shape
        tokens = x.flatten(2).permute(2, 0, 1)  # (H*W, B, C) token sequence
        tokens = self.proj_in(tokens)           # (H*W, B, embed_dim)
        tokens = self.transformer(tokens)
        tokens = self.proj_out(tokens)          # (H*W, B, C)
        x = tokens.permute(1, 2, 0).reshape(b, c, h, w)
        return x
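
A minimal shape check for the assembled model (a hypothetical snippet; the tiny hyperparameters are only there to keep the check fast on CPU):
python
import torch
from models.transunet import TransUNet

# Small transformer so the check runs quickly.
model = TransUNet(in_channels=3, num_classes=2, embed_dim=64, depth=2, num_heads=4)
x = torch.randn(1, 3, 32, 32)  # H and W must be divisible by 16 for the U-Net
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 2, 32, 32])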

4.2 models/unet.py
python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 128)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 256)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 512)
        )
        self.down4 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(512, 1024)
        )
        self.up1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(1024, 512)   # 512 upsampled + 512 from the skip
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(128, 64)
        self.outc = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: four 2x downsamplings (input H and W must be divisible by 16).
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: upsample and concatenate the matching encoder feature map.
        x = self.up1(x5)
        x = torch.cat([x, x4], dim=1)
        x = self.conv1(x)

        x = self.up2(x)
        x = torch.cat([x, x3], dim=1)
        x = self.conv2(x)

        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.conv3(x)

        x = self.up4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.conv4(x)

        logits = self.outc(x)
        return logits
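
Because the encoder halves the spatial resolution four times, the skip connections only line up when the input height and width are divisible by 16. A quick check (hypothetical snippet):
python
import torch
from models.unet import UNet

net = UNet(in_channels=3, num_classes=2)
x = torch.randn(1, 3, 256, 256)  # 256 is divisible by 16, so the skips align
print(net(x).shape)              # torch.Size([1, 2, 256, 256])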

4.3 models/transformer.py
python
import torch
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(self, embed_dim, depth, num_heads, mlp_ratio, qkv_bias):
        super().__init__()
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias

        self.layers = nn.ModuleList([
            TransformerLayer(embed_dim, num_heads, mlp_ratio, qkv_bias)
            for _ in range(depth)
        ])

    def forward(self, x):
        # x: (seq_len, batch, embed_dim), nn.MultiheadAttention's default layout.
        for layer in self.layers:
            x = layer(x)
        return x

class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio, qkv_bias):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, bias=qkv_bias)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        )

    def forward(self, x):
        # Pre-norm residual blocks; self-attention needs query, key, and value.
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]
        x = x + self.mlp(self.norm2(x))
        return x
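
The stack can be exercised on its own; note the (seq_len, batch, embed_dim) layout expected by nn.MultiheadAttention (a hypothetical check with small hyperparameters):
python
import torch
from models.transformer import Transformer

t = Transformer(embed_dim=64, depth=2, num_heads=4, mlp_ratio=4.0, qkv_bias=True)
tokens = torch.randn(100, 2, 64)  # 100 tokens, batch of 2
print(t(tokens).shape)            # torch.Size([100, 2, 64])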
5. Training Code
5.1 src/train.py
python
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from models.transunet import TransUNet
from src.dataset import get_data_loaders

def train_transunet(data_dir, epochs=100, batch_size=16, learning_rate=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TransUNet(in_channels=3, num_classes=2)
    model = model.to(device)

    train_loader, val_loader, _ = get_data_loaders(data_dir, data_dir, batch_size=batch_size)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter()

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        running_iou = 0.0

        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_iou += iou_score(outputs, masks)

        train_loss = running_loss / len(train_loader)
        train_iou = running_iou / len(train_loader)
        writer.add_scalar('Training Loss', train_loss, epoch)
        writer.add_scalar('Training IoU', train_iou, epoch)

        model.eval()
        running_val_loss = 0.0
        running_val_iou = 0.0

        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)

                running_val_loss += loss.item()
                running_val_iou += iou_score(outputs, masks)

        val_loss = running_val_loss / len(val_loader)
        val_iou = running_val_iou / len(val_loader)
        writer.add_scalar('Validation Loss', val_loss, epoch)
        writer.add_scalar('Validation IoU', val_iou, epoch)

        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, "
              f"Train IoU: {train_iou:.4f}, Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}")

    # Note: this saves the final epoch's weights; tracking the best
    # validation IoU and saving that checkpoint is a possible improvement.
    os.makedirs("weights", exist_ok=True)
    torch.save(model.state_dict(), "weights/best_model.pth")
    writer.close()

def iou_score(output, target):
    # Foreground IoU; the bitwise ops assume binary {0, 1} class maps.
    smooth = 1e-6
    output = torch.argmax(output, dim=1)
    intersection = (output & target).float().sum((1, 2))
    union = (output | target).float().sum((1, 2))
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean().item()

if __name__ == "__main__":
    data_dir = "data"
    train_transunet(data_dir)
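
To see what iou_score computes, here is a tiny worked example (hypothetical values):
python
import torch
from src.train import iou_score

# Fake logits for one 2x2 image with 2 classes.
logits = torch.tensor([[[[2.0, 0.1], [0.1, 2.0]],    # class-0 scores
                        [[0.1, 2.0], [2.0, 0.1]]]])  # class-1 scores
target = torch.tensor([[[0, 1], [0, 0]]])            # ground-truth mask

# argmax predicts [[0, 1], [1, 0]]: the foreground intersection with the
# target is 1 pixel and the union is 2 pixels, so the IoU is about 0.5.
print(iou_score(logits, target))  # ~0.5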
6. Model Evaluation
After training completes, evaluate the model's performance on the test set. Example:

6.1 src/predict.py
python
import torch
import matplotlib.pyplot as plt
from models.transunet import TransUNet
from src.dataset import get_data_loaders

def predict_and_plot_transunet(data_dir, model_path, num_samples=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TransUNet(in_channels=3, num_classes=2)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()

    _, _, test_loader = get_data_loaders(data_dir, data_dir, batch_size=1)

    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
    with torch.no_grad():
        for i, (images, masks) in enumerate(test_loader):
            if i >= num_samples:
                break

            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            predicted = torch.argmax(outputs, dim=1)

            # The images were normalized in the data loader, so the display
            # may look washed out unless you un-normalize them first.
            images = images.squeeze().cpu().numpy().transpose((1, 2, 0))
            masks = masks.squeeze().cpu().numpy()
            predicted = predicted.squeeze().cpu().numpy()

            ax = axes[i] if num_samples > 1 else axes
            ax[0].imshow(images)
            ax[0].set_title("Input Image")
            ax[0].axis('off')

            ax[1].imshow(masks, cmap='gray')
            ax[1].set_title("Ground Truth Mask")
            ax[1].axis('off')

            ax[2].imshow(predicted, cmap='gray')
            ax[2].set_title("Predicted Mask")
            ax[2].axis('off')

    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    data_dir = "data"
    model_path = "weights/best_model.pth"
    predict_and_plot_transunet(data_dir, model_path)
7. Running the Project
Make sure your dataset is placed in the corresponding folders.
From the project root, start training with the following command (run as a module so the models and src packages resolve on the import path):
bash
python -m src.train
After training completes, run the following command for evaluation and visualization:
bash
python -m src.predict
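The training script also logs loss and IoU curves with SummaryWriter, which writes to the default ./runs directory; they can be viewed with TensorBoard:
bash
tensorboard --logdir runs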
8. Report
8.1 Report Structure
Abstract: briefly introduce the project's background, goals, and main results.
Introduction: describe the background, research significance, and related work in detail.
Dataset: describe the dataset's source, structure, and preprocessing.
Method: detail the TransU-Net architecture and the training procedure.
Experiments: describe the experimental setup, training parameters, and evaluation metrics.
Results: present the experimental results, including loss curves, IoU scores, and visualizations.
Discussion: analyze the results, discuss the model's strengths and weaknesses, and suggest improvements.
Conclusion: summarize the main contributions and directions for future work.
8.2 Report Generation
Write the report in LaTeX or Markdown.
Insert the experimental results and visualization figures into the report.
Generate a PDF and place it in the report folder.
9. Feature Overview
Dataset class: RemoteSensingBuildingDataset loads and preprocesses the data.
Data loaders: get_data_loaders creates the training, validation, and test data loaders.
TransU-Net model: transunet.py defines the TransU-Net model.
Training script: train.py trains the TransU-Net model.
Prediction script: predict.py evaluates the trained model and visualizes input images, ground-truth masks, and predictions.
