使用 U-Net 作为我们的基础模型 遥感智慧能源(煤炭,石油,太阳能等)发电厂遥感分割数据集 训练如何构建深度学习智慧能源发电厂分割系统
文章目录
遥感智慧能源(煤炭,石油,太阳能等)发电厂遥感分割数据集,4400余对遥感影像(每对影像包含1m分辨率与30m分辨率),19GB数据量,分割类别按照电厂燃料类别区分,共区分为煤炭,石油,天然气,其他化石燃料,核能,水利发电,太阳能,风能,地热能,废热,生物质11种类型,并分割出具体发电厂区域。
训练一个基于遥感影像的智慧能源发电厂分割模型,包括数据准备、环境搭建、数据预处理、模型选择与配置、训练过程以及性能评估等。
代码示例,仅供参考。
帮助你从零开始训练这个数据集。
1. 数据准备
文件结构
确保你的数据集文件结构如下:
energy_plant_segmentation/
├── images/
│ ├── 1m_resolution/
│ └── 30m_resolution/
└── masks/
├── 1m_resolution/
└── 30m_resolution/
每对影像(1m分辨率与30m分辨率)对应一个标注掩膜(mask),用于表示不同类型的电厂区域。
类别映射文件(classes.txt)
创建一个 classes.txt
文件,列出所有11个类别:
coal
oil
natural_gas
other_fossil_fuel
nuclear
hydro
solar
wind
geothermal
waste_heat
biomass
2. 环境搭建
安装必要的依赖包:
pip install torch torchvision torchaudio
git clone https://github.com/qubvel/segmentation_models.pytorch.git
cd segmentation_models.pytorch
pip install -r requirements.txt
pip install -e .
3. 数据预处理
数据集包含两种分辨率的图像(1m和30m),将使用1m分辨率的数据进行训练。对于不同的分辨率,可以考虑分别训练模型或者通过上采样/下采样统一分辨率。
创建自定义Dataset类
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
class EnergyPlantSegmentationDataset(Dataset):
def __init__(self, img_dir, mask_dir, transform=None):
self.img_dir = img_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(img_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.images[idx])
mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '_mask.png'))
image = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path).convert("L")
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
return image, mask
4. 模型选择与配置
我们选择使用 U-Net 作为我们的基础模型,因为它在医学图像分割等领域表现出色,同样适用于遥感影像分割任务。
import segmentation_models_pytorch as smp
model = smp.Unet(
encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes=11, # model output channels (number of classes in your dataset)
)
5. 训练过程
数据加载器
transform = transforms.Compose([
transforms.Resize((256, 256)), # 调整尺寸以适应模型输入要求
transforms.ToTensor(),
])
train_dataset = EnergyPlantSegmentationDataset(img_dir='path/to/train/images', mask_dir='path/to/train/masks', transform=transform)
val_dataset = EnergyPlantSegmentationDataset(img_dir='path/to/val/images', mask_dir='path/to/val/masks', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
训练循环
import torch.optim as optim
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(epochs):
model.train()
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader.dataset)}")
6. 性能评估
在验证集上评估模型性能:
model.eval()
with torch.no_grad():
correct = 0
total = 0
for inputs, labels in val_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the model on the validation images: {100 * correct / total}%')
7. 构建深度学习智慧能源发电厂分割系统
你可以进一步构建一个Web应用来展示预测结果,或者开发一个桌面应用程序来处理用户上传的遥感影像并输出分割结果。
jgck,仅供参考,